Merge pull request #23886 from wdirons:add_pythonpath_to_bazelrc_if_referenced
PiperOrigin-RevId: 223204341
diff --git a/.github/ISSUE_TEMPLATE/40-tflite-op-request.md b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md
new file mode 100644
index 0000000..7b39127
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/40-tflite-op-request.md
@@ -0,0 +1,24 @@
+---
+name: TensorFlow Lite Op Request
+about: Use this template for reporting ops you are using or missing.
+
+---
+
+
+**System information**
+- OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
+- TensorFlow installed from (source or binary):
+- TensorFlow version (or github SHA if from source):
+
+
+**Provide the text output from tflite_convert**
+
+```
+# Copy and paste here
+```
+
+Also, please include a link to a GraphDef or the model if possible.
+
+**Any other info / logs**
+
+Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
diff --git a/README.md b/README.md
index 8af5370..902753e 100644
--- a/README.md
+++ b/README.md
@@ -9,12 +9,14 @@
|-----------------|
| [](https://www.tensorflow.org/api_docs/) |
-**TensorFlow** is an open source software library for numerical computation using
-data flow graphs. The graph nodes represent mathematical operations, while
+**TensorFlow** is an open source software library for numerical computation
+using data flow graphs. The graph nodes represent mathematical operations, while
the graph edges represent the multidimensional data arrays (tensors) that flow
-between them. This flexible architecture enables you to deploy computation to one
-or more CPUs or GPUs in a desktop, server, or mobile device without rewriting
-code. TensorFlow also includes [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard), a data visualization toolkit.
+between them. This flexible architecture enables you to deploy computation to
+one or more CPUs or GPUs in a desktop, server, or mobile device without
+rewriting code. TensorFlow also includes
+[TensorBoard](https://github.com/tensorflow/tensorboard), a data visualization
+toolkit.
TensorFlow was originally developed by researchers and engineers
working on the Google Brain team within Google's Machine Intelligence Research
@@ -118,15 +120,17 @@
**Linux CPU with Intel® MKL-DNN** Python 2.7<br> **Linux CPU with Intel® MKL-DNN** Python 3.4<br> **Linux CPU with Intel® MKL-DNN** Python 3.5<br> **Linux CPU with Intel® MKL-DNN** Python 3.6 | [](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.11.0 py2.7](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp27-cp27mu-linux_x86_64.whl)<br>[1.11.0 py3.4](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp34-cp34m-linux_x86_64.whl)<br>[1.11.0 py3.5](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp35-cp35m-linux_x86_64.whl)<br>[1.11.0 py3.6](https://storage.googleapis.com/intel-optimized-tensorflow/tensorflow-1.11.0-cp36-cp36m-linux_x86_64.whl)
## For more information
-* [TensorFlow Website](https://www.tensorflow.org)
-* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
-* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
-* [TensorFlow Twitter](https://twitter.com/tensorflow)
-* [TensorFlow Blog](https://medium.com/tensorflow)
-* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
-* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
-* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
-* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
+
+* [TensorFlow Website](https://www.tensorflow.org)
+* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/)
+* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
+* [TensorFlow Twitter](https://twitter.com/tensorflow)
+* [TensorFlow Blog](https://medium.com/tensorflow)
+* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si)
+* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap)
+* [TensorFlow White Papers](https://www.tensorflow.org/about/bib)
+* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ)
+* [TensorFlow Visualization Toolkit](https://github.com/tensorflow/tensorboard)
Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate.
diff --git a/configure.py b/configure.py
index 17ab7a0..f087da0 100644
--- a/configure.py
+++ b/configure.py
@@ -243,7 +243,7 @@
if environ_cp.get('PYTHONPATH'):
python_paths = environ_cp.get('PYTHONPATH').split(':')
if python_lib_path in python_paths:
- write_action_env_to_bazelrc('PYTHONPATH', environ_cp.get('PYTHONPATH'))
+ write_action_env_to_bazelrc('PYTHONPATH', environ_cp.get('PYTHONPATH'))
# Write tools/python_bin_path.sh
with open(
@@ -866,7 +866,7 @@
cuda_toolkit_paths_full = [
os.path.join(cuda_toolkit_path, x) for x in cuda_rt_lib_paths
]
- if any([os.path.exists(x) for x in cuda_toolkit_paths_full]):
+ if any(os.path.exists(x) for x in cuda_toolkit_paths_full):
break
# Reset and retry
@@ -1701,6 +1701,7 @@
config_info_line('nohdfs', 'Disable HDFS support.')
config_info_line('noignite', 'Disable Apacha Ignite support.')
config_info_line('nokafka', 'Disable Apache Kafka support.')
+ config_info_line('nonccl', 'Disable NVIDIA NCCL support.')
if __name__ == '__main__':
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 17577af..fd4b942 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -246,6 +246,12 @@
visibility = ["//visibility:public"],
)
+config_setting(
+ name = "no_nccl_support",
+ define_values = {"no_nccl_support": "true"},
+ visibility = ["//visibility:public"],
+)
+
# Crosses between platforms and file system libraries not supported on those
# platforms due to limitations in nested select() statements.
config_setting(
diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py
index 2efb884..f13623b 100644
--- a/tensorflow/api_template.__init__.py
+++ b/tensorflow/api_template.__init__.py
@@ -21,8 +21,6 @@
import os as _os
# pylint: disable=g-bad-import-order
-from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
-
from tensorflow.python.tools import component_api_helper as _component_api_helper
_component_api_helper.package_hook(
parent_package_str=__name__,
@@ -30,8 +28,6 @@
# API IMPORTS PLACEHOLDER
-from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top
-
# Make sure directory containing top level submodules is in
# the __path__ so that "from tensorflow.foo import bar" works.
# We're using bitwise, but there's nothing special about that.
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 84238ff..f653e58 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -121,6 +121,7 @@
":c_api",
":c_api_internal",
"//tensorflow/c/eager:c_api",
+ "//tensorflow/c/eager:c_api_internal",
"//tensorflow/compiler/jit:flags",
"//tensorflow/contrib/tpu:all_ops",
"//tensorflow/core:core_cpu",
@@ -263,7 +264,7 @@
tf_cc_test(
name = "c_api_experimental_test",
- size = "small",
+ size = "medium",
srcs = ["c_api_experimental_test.cc"],
data = ["testdata/tf_record"],
linkopts = select({
@@ -274,8 +275,11 @@
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
deps = [
+ ":c_api",
":c_api_experimental",
":c_test_util",
+ "//tensorflow/c/eager:c_api",
+ "//tensorflow/c/eager:c_api_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index f160f20..69de4cb 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -15,7 +15,10 @@
#include "tensorflow/c/c_api_experimental.h"
+#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
@@ -23,6 +26,7 @@
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/net.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@@ -8740,8 +8744,55 @@
TF_DeleteStatus(status);
}
-TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
- const char* errMsg) {
+struct TFE_ExecuteOpNotification {
+ TFE_ExecuteOpNotification() : status(TF_NewStatus(), TF_DeleteStatus) {}
+ tensorflow::Notification n;
+ std::unique_ptr<tensorflow::Thread> thread;
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status;
+};
+
+TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(TFE_Op* op,
+ TFE_TensorHandle** retvals,
+ int* num_retvals,
+ TF_Status* status) {
+ TFE_ExecuteOpNotification* n = new TFE_ExecuteOpNotification;
+
+ n->thread.reset(op->operation.EagerContext()->TFEnv()->StartThread(
+ tensorflow::ThreadOptions(), "ExecuteOpThread",
+ [op, retvals, num_retvals, n]() {
+ TFE_Execute(op, retvals, num_retvals, n->status.get());
+ n->n.Notify();
+ }));
+
+ return n;
+}
+
+void TFE_ExecuteOpNotificationWaitAndDelete(
+ TFE_ExecuteOpNotification* notification, TF_Status* status) {
+ if (notification == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Passed in notification is a nullptr.");
+
+ return;
+ }
+ if (notification->thread == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Passed in notification didn't start a thread correctly. Cleaning up "
+ "this notification. Please re-execute the operation to get a new "
+ "notification.");
+
+ delete notification;
+ return;
+ }
+
+ notification->n.WaitForNotification();
+
+ status->status = notification->status->status;
+
+ delete notification;
+}
+
+void TF_MakeInternalErrorStatus(TF_Status* status, const char* errMsg) {
status->status = tensorflow::errors::Internal(errMsg);
}
@@ -8815,3 +8866,7 @@
void TF_InitMain(const char* usage, int* argc, char*** argv) {
tensorflow::port::InitMain(usage, argc, argv);
}
+
+int TF_PickUnusedPortOrDie() {
+ return tensorflow::internal::PickUnusedPortOrDie();
+}
diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h
index 25c03df..c04cd44 100644
--- a/tensorflow/c/c_api_experimental.h
+++ b/tensorflow/c/c_api_experimental.h
@@ -180,6 +180,25 @@
TF_CAPI_EXPORT extern void TFE_TensorHandlePrintDebugString(
TFE_TensorHandle* handle);
+typedef struct TFE_ExecuteOpNotification TFE_ExecuteOpNotification;
+
+// Allows invoking a kernel asynchronously, and explicitly returns a
+// notification that can be waited upon. This always executes the kernel in a
+// new thread.
+// 1. `retvals` and `num_retvals` can only be consumed after
+// `TFE_ExecuteOp` returns successfully. They shouldn't be used
+// if the return is unsuccessful
+// 2. These new APIs cannot be used together with the TFE context level async
+// support.
+TF_CAPI_EXPORT extern TFE_ExecuteOpNotification* TFE_ExecuteOpInNewThread(
+ TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
+ TF_Status* status);
+
+// Waits to complete the op execution, and cleans up the notification.
+// Errors reported by op execution are set in `status`.
+TF_CAPI_EXPORT extern void TFE_ExecuteOpNotificationWaitAndDelete(
+ TFE_ExecuteOpNotification* notification, TF_Status* status);
+
TF_CAPI_EXPORT extern void TF_MakeInternalErrorStatus(TF_Status* status,
const char* errMsg);
@@ -218,6 +237,10 @@
// this to be called.
TF_CAPI_EXPORT void TF_InitMain(const char* usage, int* argc, char*** argv);
+// Platform-specific implementation to return an unused port. (This should used
+// in tests only.)
+TF_CAPI_EXPORT int TF_PickUnusedPortOrDie();
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc
index 881dbaf..daa7701 100644
--- a/tensorflow/c/c_api_experimental_test.cc
+++ b/tensorflow/c/c_api_experimental_test.cc
@@ -15,6 +15,8 @@
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/c_test_util.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -173,5 +175,126 @@
EXPECT_EQ(id, 0);
}
+TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Simple) {
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ TFE_TensorHandle* m = TestMatrixTensorHandle();
+
+ TFE_Op* matmul_op = MatMulOp(ctx, m, m);
+
+ TFE_TensorHandle* retvals[1] = {nullptr};
+ int num_retvals = 1;
+
+ auto* r =
+ TFE_ExecuteOpInNewThread(matmul_op, &retvals[0], &num_retvals, status);
+
+ TFE_ExecuteOpNotificationWaitAndDelete(r, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(7, product[0]);
+ EXPECT_EQ(10, product[1]);
+ EXPECT_EQ(15, product[2]);
+ EXPECT_EQ(22, product[3]);
+
+ TFE_DeleteOp(matmul_op);
+ TFE_DeleteTensorHandle(m);
+
+ TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteContext(ctx);
+ TF_DeleteStatus(status);
+}
+
+// Perform a send/recv test. Recv blocks, so they need to be executed
+// asynchronously.
+TEST(CAPI_EXPERIMENTAL, TFE_ExecuteOpInNewThreadTest_Blocking) {
+ TF_Status* status = TF_NewStatus();
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_Context* ctx = TFE_NewContext(opts, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_DeleteContextOptions(opts);
+
+ // Returns a 2x2 float32 Tensor on the CPU, with data 1., 2., 3., 4.
+ TFE_TensorHandle* m = TestMatrixTensorHandle();
+
+ // Build a send op.
+ TFE_Op* send_op = TFE_NewOp(ctx, "_Send", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(send_op, m, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ string tensor_name = "Tensor";
+ TFE_OpSetAttrType(send_op, "T", TF_FLOAT);
+ TFE_OpSetAttrString(send_op, "tensor_name", tensor_name.c_str(),
+ tensor_name.size());
+ string send_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+ TFE_OpSetAttrString(send_op, "send_device", send_device.c_str(),
+ send_device.size());
+ TFE_OpSetAttrInt(send_op, "send_device_incarnation", 1234);
+ string recv_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+ TFE_OpSetAttrString(send_op, "recv_device", recv_device.c_str(),
+ recv_device.size());
+ TFE_OpSetAttrBool(send_op, "client_terminated", true);
+
+ // Build a recv op.
+ TFE_Op* recv_op = TFE_NewOp(ctx, "_Recv", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TFE_OpSetAttrType(recv_op, "tensor_type", TF_FLOAT);
+ TFE_OpSetAttrString(recv_op, "tensor_name", tensor_name.c_str(),
+ tensor_name.size());
+ TFE_OpSetAttrString(recv_op, "send_device", send_device.c_str(),
+ send_device.size());
+ TFE_OpSetAttrInt(recv_op, "send_device_incarnation", 1234);
+ TFE_OpSetAttrString(recv_op, "recv_device", recv_device.c_str(),
+ recv_device.size());
+ TFE_OpSetAttrBool(recv_op, "client_terminated", true);
+
+ TFE_TensorHandle* send_retvals;
+ int send_num_retvals = 0;
+ auto* send_result = TFE_ExecuteOpInNewThread(send_op, &send_retvals,
+ &send_num_retvals, status);
+
+ TFE_TensorHandle* recv_retvals[1] = {nullptr};
+ int recv_num_retvals = 1;
+ auto* recv_result = TFE_ExecuteOpInNewThread(recv_op, &recv_retvals[0],
+ &recv_num_retvals, status);
+
+ TFE_ExecuteOpNotificationWaitAndDelete(send_result, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_ExecuteOpNotificationWaitAndDelete(recv_result, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ TF_Tensor* t = TFE_TensorHandleResolve(recv_retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+ float product[4] = {0};
+ EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
+ memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
+ TF_DeleteTensor(t);
+ EXPECT_EQ(1, product[0]);
+ EXPECT_EQ(2, product[1]);
+ EXPECT_EQ(3, product[2]);
+ EXPECT_EQ(4, product[3]);
+
+ TFE_DeleteOp(send_op);
+ TFE_DeleteOp(recv_op);
+ TFE_DeleteTensorHandle(m);
+
+ TFE_DeleteTensorHandle(recv_retvals[0]);
+ TFE_DeleteContext(ctx);
+ TF_DeleteStatus(status);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index ba3d853..c34a84f 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -50,6 +50,7 @@
],
"//conditions:default": [],
}) + [
+ "@com_google_absl//absl/memory",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:eager_client",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
@@ -143,6 +144,7 @@
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 1920449..027d752 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -21,6 +21,7 @@
#include <string>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
@@ -80,7 +81,7 @@
const std::vector<string>& remote_workers,
tensorflow::WorkerCacheInterface* worker_cache,
std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
- std::vector<tensorflow::Device*> remote_devices;
+ std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
tensorflow::Status status;
// TODO(nareshmodi) do this in parallel instead of serially.
for (const string& remote_worker : remote_workers) {
@@ -93,7 +94,7 @@
status = s;
if (s.ok()) {
for (tensorflow::Device* d : *devices) {
- remote_devices.push_back(d);
+ remote_devices.emplace_back(d);
}
}
n.Notify();
@@ -101,7 +102,7 @@
n.WaitForNotification();
}
std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
- new tensorflow::DeviceMgr(remote_devices));
+ new tensorflow::DeviceMgr(std::move(remote_devices)));
TF_RETURN_IF_ERROR(status);
@@ -262,13 +263,13 @@
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<tensorflow::Device>> devices;
status->status = tensorflow::DeviceFactory::AddDevices(
opts->session_options.options, "/job:localhost/replica:0/task:0",
&devices);
if (!status->status.ok()) return nullptr;
std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
- new tensorflow::DeviceMgr(devices));
+ new tensorflow::DeviceMgr(std::move(devices)));
tensorflow::Rendezvous* r =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
@@ -410,6 +411,18 @@
: d->name().c_str();
}
+const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
+ TF_Status* status) {
+ if (h == nullptr || h->handle == nullptr) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "The passed in handle is a nullptr");
+ return nullptr;
+ }
+ tensorflow::Device* d = h->handle->device();
+ return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
+ : d->name().c_str();
+}
+
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
TFE_TensorHandle* h, TF_Status* status) {
if (h == nullptr || h->handle == nullptr) {
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index b2454d8..8d6c8d9 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -169,10 +169,33 @@
TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h,
int dim_index,
TF_Status* status);
+
+// Returns the device of the operation that produced `h`.
+// If `h` was produced by a copy, returns the destination device of
+// the copy. Note that returned device name is not always the device
+// holding the tensor handle's memory. If you want the latter, use
+// TFE_TensorHandleBackingDeviceName.
+// This function will block till the operation that produces `h` has completed.
+//
+// Device on which the kernel of the operation that produced `h` ran.
+//
+// If `h` was produced by a copy, returns the destination device of
+// the copy.
+//
+// Note that returned device name is not always the device that owns the memory
+// that backs the tensor handle. For the latter see
+// TFE_TensorHandleBackingDeviceName.
+//
// This function will block till the operation that produces `h` has completed.
TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(
TFE_TensorHandle* h, TF_Status* status);
+// Returns the name of the device in whose memory `h` resides.
+//
+// This function will block till the operation that produces `h` has completed.
+TF_CAPI_EXPORT extern const char* TFE_TensorHandleBackingDeviceName(
+ TFE_TensorHandle* h, TF_Status* status);
+
// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor
// with `h`. On success, `status` is set to OK. On failure, `status` reflects
// the error and a nullptr is returned.
diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc
index 0045bb5..6b39b79 100644
--- a/tensorflow/c/eager/c_api_test.cc
+++ b/tensorflow/c/eager/c_api_test.cc
@@ -16,6 +16,7 @@
#include "tensorflow/c/eager/c_api.h"
#include <string.h>
+#include "absl/strings/match.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -794,6 +795,14 @@
TF_SetStatus(status.get(), TF_OK, "");
+ device_name = TFE_TensorHandleBackingDeviceName(h, status.get());
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
+ ASSERT_EQ(device_name, nullptr);
+ ASSERT_EQ("The passed in handle is a nullptr",
+ string(TF_Message(status.get())));
+
+ TF_SetStatus(status.get(), TF_OK, "");
+
int num_dims = TFE_TensorHandleNumDims(h, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
ASSERT_EQ(num_dims, -1);
@@ -809,6 +818,62 @@
string(TF_Message(status.get())));
}
+TEST(CAPI, TensorHandleDevices) {
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+ TF_NewStatus(), TF_DeleteStatus);
+ TFE_ContextOptions* opts = TFE_NewContextOptions();
+ TFE_Context* ctx = TFE_NewContext(opts, status.get());
+ TFE_DeleteContextOptions(opts);
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+ TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+ const char* device_name = TFE_TensorHandleDeviceName(hcpu, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ ASSERT_TRUE(absl::StrContains(device_name, "CPU:0")) << device_name;
+ const char* backing_device_name =
+ TFE_TensorHandleBackingDeviceName(hcpu, status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
+ << backing_device_name;
+
+ // Disable the test if no GPU is present.
+ string gpu_device_name;
+ if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
+ TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
+ hcpu, ctx, gpu_device_name.c_str(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ TFE_Op* shape_op = ShapeOp(ctx, hgpu);
+ TFE_OpSetDevice(shape_op, gpu_device_name.c_str(), status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+ TFE_TensorHandle* retvals[1];
+ int num_retvals = 1;
+ TFE_Execute(shape_op, &retvals[0], &num_retvals, status.get());
+ ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+ // .device of shape is GPU since the op is executed on GPU
+ device_name = TFE_TensorHandleDeviceName(retvals[0], status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ ASSERT_TRUE(absl::StrContains(device_name, "GPU:0")) << device_name;
+
+ // .backing_device of shape is CPU since the tensor is backed by CPU
+ backing_device_name =
+ TFE_TensorHandleBackingDeviceName(retvals[0], status.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ ASSERT_TRUE(absl::StrContains(backing_device_name, "CPU:0"))
+ << backing_device_name;
+
+ TFE_DeleteOp(shape_op);
+ TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteTensorHandle(hgpu);
+ }
+
+ TFE_DeleteTensorHandle(hcpu);
+ TFE_ContextAsyncWait(ctx, status.get());
+ EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+ TFE_DeleteContext(ctx);
+}
+
void Execute_MatMul_CPU(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc
index 008f088..bd38127 100644
--- a/tensorflow/c/eager/c_api_test_util.cc
+++ b/tensorflow/c/eager/c_api_test_util.cc
@@ -104,6 +104,19 @@
return op;
}
+TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) {
+ TF_Status* status = TF_NewStatus();
+
+ TFE_Op* op = TFE_NewOp(ctx, "Shape", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TFE_OpAddInput(op, a, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+ TF_DeleteStatus(status);
+ TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
+
+ return op;
+}
+
TFE_TensorHandle* TestAxisTensorHandle() {
int64_t dims[] = {1};
int data[] = {1};
diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h
index 474cae6..75ef945 100644
--- a/tensorflow/c/eager/c_api_test_util.h
+++ b/tensorflow/c/eager/c_api_test_util.h
@@ -37,6 +37,9 @@
// Return a matmul op multiplying `a` by `b`.
TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b);
+// Return a shape op fetching the shape of `a`.
+TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a);
+
// Return an 1-D INT32 tensor containing a single value 1.
TFE_TensorHandle* TestAxisTensorHandle();
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index 83353b7..a09becc 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -489,6 +489,7 @@
"image_ops",
"io_ops",
"linalg_ops",
+ "list_ops",
"logging_ops",
"lookup_ops",
"manip_ops",
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index b17bc65..697599f 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -164,7 +164,8 @@
}
// Generate methods for args (inputs).
-Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
+Status GenArgMethods(const tf2xla::Config& config,
+ const xla::ProgramShapeProto& ps,
const CompileResult& compile_result, string* methods) {
size_t num_args = ps.parameters_size();
if (config.feed_size() != num_args) {
@@ -204,7 +205,7 @@
// Generate methods for results (outputs).
Status GenResultMethods(const tf2xla::Config& config,
- const xla::ProgramShape& ps, string* methods) {
+ const xla::ProgramShapeProto& ps, string* methods) {
if (ps.result().element_type() != xla::TUPLE) {
// The XlaCompiler we use to build the xla computation always generates a
// tuple result, and we rely on this to simplify code generation.
@@ -336,7 +337,7 @@
ExtractEntryParamBufferInfos(buffer_infos);
std::vector<BufferInfo> buffer_infos_for_temps =
ExtractTempBufferInfos(buffer_infos);
- const xla::ProgramShape& ps = compile_result.program_shape;
+ const xla::ProgramShapeProto& ps = compile_result.program_shape;
string methods_arg, methods_result;
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
@@ -548,8 +549,8 @@
static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
// Shape of the args and results.
- static const xla::ProgramShape* StaticProgramShape() {
- static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
+ static const xla::ProgramShapeProto* StaticProgramShape() {
+ static const xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
return kShape;
}
@@ -615,11 +616,11 @@
Status GenerateMetadata(const CodegenOpts& opts,
const CompileResult& compile_result,
MetadataResult* metadata_result) {
- std::unique_ptr<xla::ProgramShape> program_shape;
+ std::unique_ptr<xla::ProgramShapeProto> program_shape;
if (opts.gen_program_shape) {
program_shape =
- absl::make_unique<xla::ProgramShape>(compile_result.program_shape);
+ absl::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);
// The parameter names are currently meaningless, and redundant with the
// rest of our metadata, so clear them out to avoid confusion and save
@@ -631,8 +632,8 @@
// a shim that evaluates to nullptr, which is what we want.
ProtobufToEmbed program_shape_protobuf{
- CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape",
- program_shape.get()};
+ CreateUniqueIdentifier(opts, "ProgramShapeProto"),
+ "xla::ProgramShapeProto", program_shape.get()};
ProtobufToEmbed hlo_profile_printer_data_protobuf{
CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h
index 90410c4..9485e86 100644
--- a/tensorflow/compiler/aot/codegen.h
+++ b/tensorflow/compiler/aot/codegen.h
@@ -57,7 +57,7 @@
std::vector<string> header_variable_decls;
// program_shape_access_shim is a C++ expression that constructs the
- // xla::ProgramShape instance for the CompileResult passed to
+ // xla::ProgramShapeProto instance for the CompileResult passed to
// GenerateMetadata.
string program_shape_access_shim;
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index bb288d2..c1788ca 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -181,13 +181,15 @@
BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
5, {}));
- compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
- {
- xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
- xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
- },
- xla::ShapeUtil::MakeTupleShape(
- {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}));
+ compile_result.program_shape =
+ xla::ShapeUtil::MakeProgramShape(
+ {
+ xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
+ xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
+ },
+ xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}))
+ .ToProto();
compile_result.entry_point = "entry_point";
compile_result.pointer_size = 8;
diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden
index e4d8a02..a2cdab5 100644
--- a/tensorflow/compiler/aot/codegen_test_h.golden
+++ b/tensorflow/compiler/aot/codegen_test_h.golden
@@ -22,7 +22,7 @@
void* result, const xla::ExecutableRunOptions* run_options,
const void** args, void** temps, tensorflow::int64* profile_counters);
-extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[];
+extern "C" char __tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[];
namespace foo {
@@ -253,10 +253,10 @@
}
// Shape of the args and results.
- static const xla::ProgramShape* StaticProgramShape() {
- static const xla::ProgramShape* kShape = []() {
- xla::ProgramShape* proto = new xla::ProgramShape;
- proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52);
+ static const xla::ProgramShapeProto* StaticProgramShape() {
+ static const xla::ProgramShapeProto* kShape = []() {
+ xla::ProgramShapeProto* proto = new xla::ProgramShapeProto;
+ proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52);
return proto;
}();
return kShape;
diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden
index eb001c5..ce8e5ec 100644
--- a/tensorflow/compiler/aot/codegen_test_o.golden
+++ b/tensorflow/compiler/aot/codegen_test_o.golden
Binary files differ
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 2b5f97b..3bc99ef 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -56,8 +56,8 @@
return errors::Unknown("Couldn't get XLA program shape: ",
pshape_or.status().error_message());
}
- compile_result->program_shape = *pshape_or.ValueOrDie();
- xla::ProgramShape* pshape = &compile_result->program_shape;
+ compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
+ xla::ProgramShapeProto* pshape = &compile_result->program_shape;
std::vector<const xla::Shape*> arg_layouts;
arg_layouts.reserve(pshape->parameters_size());
for (int i = 0; i < pshape->parameters_size(); ++i) {
diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h
index e03c5b1..ee7bb26 100644
--- a/tensorflow/compiler/aot/compile.h
+++ b/tensorflow/compiler/aot/compile.h
@@ -33,9 +33,9 @@
struct CompileResult {
// Contains object file and meta-info.
std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
- xla::ProgramShape program_shape; // Static shape of args and results.
- string entry_point; // Name of generated function.
- int pointer_size = 0; // Size of a pointer in bytes.
+ xla::ProgramShapeProto program_shape; // Static shape of args and results.
+ string entry_point; // Name of generated function.
+ int pointer_size = 0; // Size of a pointer in bytes.
};
// CompileGraph compiles the graph_def into an object file containing a function
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index f10852c..711feed 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -526,7 +526,7 @@
// muladd has the program shape defined.
MatMulAndAddComp muladd;
- const xla::ProgramShape* muladd_shape = muladd.ProgramShape();
+ const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape();
ASSERT_TRUE(muladd_shape != nullptr);
ASSERT_EQ(muladd_shape->parameters_size(), 2);
EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2));
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 682c0f0..be91ed4 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -23,7 +23,6 @@
load("//tensorflow:tensorflow.bzl", "cc_header_only_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
-load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -38,7 +37,7 @@
":xla_cpu_device",
":xla_cpu_jit",
"//tensorflow/compiler/plugin",
- ] + if_cuda_is_configured([
+ ] + if_cuda([
":xla_gpu_device",
":xla_gpu_jit",
]),
@@ -51,6 +50,7 @@
deps = [
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops",
+ "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:cpu_plugin",
@@ -268,6 +268,7 @@
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
@@ -736,7 +737,10 @@
visibility = [
":friends",
],
- deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"],
+ deps = [
+ "//tensorflow/compiler/jit/ops:xla_ops_grad",
+ "//tensorflow/compiler/jit/ops:xla_ops_wrapper_py",
+ ],
)
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
index 11df946..48a23a4 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -42,14 +42,8 @@
.ok());
}
- void TearDown() override {
- for (Device* device : devices_) {
- delete device;
- }
- }
-
private:
- std::vector<Device*> devices_;
+ std::vector<std::unique_ptr<Device>> devices_;
};
using ::tensorflow::testing::FindNodeByName;
diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
index 7386660..0f872a4 100644
--- a/tensorflow/compiler/jit/create_xla_launch_op_test.cc
+++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc
@@ -59,8 +59,9 @@
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 1});
+ std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
- options, "/job:localhost/replica:0/task:0", &devices_));
+ options, "/job:localhost/replica:0/task:0", &devices));
FunctionDefLibrary proto;
for (const auto& fdef : flib) {
@@ -69,7 +70,7 @@
lib_def_ = absl::make_unique<FunctionLibraryDefinition>(
OpRegistry::Global(), proto);
OptimizerOptions opts;
- device_mgr_ = absl::make_unique<DeviceMgr>(devices_);
+ device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr);
@@ -77,7 +78,6 @@
}
FunctionLibraryRuntime* flr_;
- std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc
index 28ec37b..bcc3213 100644
--- a/tensorflow/compiler/jit/encapsulate_util.cc
+++ b/tensorflow/compiler/jit/encapsulate_util.cc
@@ -86,7 +86,7 @@
continue;
} else if (src_xla_computation && !dst_xla_computation) {
if (src_outside_compilation) {
- // Case 1d: outside compilation to host computation control edge.
+ // Case 1c: outside compilation to host computation control edge.
edges_to_remove.push_back(e);
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
@@ -94,7 +94,7 @@
}
} else if (!src_xla_computation && dst_xla_computation) {
if (dst_outside_compilation) {
- // Case 1d: host computation control to outside compilation edge.
+ // Case 1c: host computation control to outside compilation edge.
edges_to_remove.push_back(e);
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
@@ -103,40 +103,24 @@
} else { // src_xla_computation && dst_xla_computation
if (*src_xla_computation != *dst_xla_computation) {
if (src_outside_compilation && dst_outside_compilation) {
- // Case 1c: outside compilation to outside compilation control edge.
+ // Case 1b: outside compilation to outside compilation control edge.
edges_to_remove.push_back(e);
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
} else if (src_outside_compilation && !dst_outside_compilation) {
- // Case 1b: outside compilation to another XLA computaition control
+ // Case 1a: outside compilation to another XLA computaition control
// edge.
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->src(), kXlaConnectedToOtherXlaComputationAttrName,
*dst_xla_computation));
} else if (!src_outside_compilation && dst_outside_compilation) {
- // Case 1b: another XLA computaition to outside compilation control
+ // Case 1a: another XLA computaition to outside compilation control
// edge.
TF_RETURN_IF_ERROR(AppendToListAttr<string>(
e->dst(), kXlaConnectedFromOtherXlaComputationAttrName,
*src_xla_computation));
}
- } else { // *src_xla_computation == *dst_xla_computation
- if (src_outside_compilation && dst_outside_compilation) {
- if (*src_outside_compilation != *dst_outside_compilation) {
- // Case 1c: outside compilation to outside compilation control edge.
- edges_to_remove.push_back(e);
-
- TF_RETURN_IF_ERROR(AppendToListAttr<string>(
- e->dst(), kXlaControlDependenciesAttrName, e->src()->name()));
- }
- } else if (src_outside_compilation && !dst_outside_compilation) {
- // Case 1a: outside compilation to its XLA computation control edge.
- ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
- } else if (!src_outside_compilation && dst_outside_compilation) {
- // Case 1a: XLA computation to outside compilation in it control edge.
- ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
- }
}
}
}
@@ -181,12 +165,6 @@
edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
}
- } else { // *src_xla_computation == *dst_xla_computation
- if (src_outside_compilation && dst_outside_compilation &&
- *src_outside_compilation != *dst_outside_compilation) {
- edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
- VLOG(4) << "XLA -> XLA edge: " << e->DebugString();
- }
}
}
@@ -594,14 +572,242 @@
return Status::OK();
}
+// Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PreprocessEdgesBetweenOutsideCompilations` for details.
+Status PreprocessControlEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ // Gather edges to remove. We should not remove the edge while iterating.
+ std::vector<const Edge*> edges_to_remove;
+ for (const Edge* e : g->edges()) {
+ if (!e->IsControlEdge()) {
+ continue;
+ }
+
+ auto src_outside_compilation =
+ GetStringAttr(*e->src(), outside_compilation_attr_name);
+ auto dst_outside_compilation =
+ GetStringAttr(*e->dst(), outside_compilation_attr_name);
+
+ if (src_outside_compilation && dst_outside_compilation) {
+ if (*src_outside_compilation != *dst_outside_compilation) {
+ // Case 1a: outside compilation to outside compilation control edge.
+ edges_to_remove.push_back(e);
+
+ TF_RETURN_IF_ERROR(AppendToListAttr<string>(
+ e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName,
+ e->src()->name()));
+ }
+ } else if (src_outside_compilation && !dst_outside_compilation) {
+ // Case 1b: outside compilation to its XLA computation control edge.
+ ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
+ } else if (!src_outside_compilation && dst_outside_compilation) {
+ // Case 1b: XLA computation to outside compilation in it control edge.
+ ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
+ }
+ }
+
+ for (auto e : edges_to_remove) {
+ g->RemoveEdge(e);
+ }
+ return Status::OK();
+}
+
+// Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PreprocessEdgesBetweenOutsideCompilations` for details.
+Status PreprocessDataEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ // Gather edges between outside compilation and host computation. Notice that
+ // we do not store `Edge*` directly because we remove some nodes while adding
+ // Identity nodes, and those Edge pointers might be invalidated.
+ struct EdgeInfo {
+ int dst_input, dst_node_id;
+ };
+ std::vector<EdgeInfo> edges;
+ for (const Edge* e : g->edges()) {
+ if (e->IsControlEdge()) {
+ continue;
+ }
+
+ auto src_outside_compilation =
+ GetStringAttr(*e->src(), outside_compilation_attr_name);
+ auto dst_outside_compilation =
+ GetStringAttr(*e->dst(), outside_compilation_attr_name);
+
+ if (src_outside_compilation && dst_outside_compilation &&
+ *src_outside_compilation != *dst_outside_compilation) {
+ edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
+ VLOG(4) << "Oc -> oc edge: " << e->DebugString();
+ }
+ }
+
+ // Remove the edge from host to outside compilation. Add a placeholder as
+ // outside compilation node input.
+ std::map<string, Node*> placeholders;
+ for (int i = 0; i < edges.size(); i++) {
+ Node* dst = g->FindNodeId(edges[i].dst_node_id);
+ const Edge* e;
+ TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
+ Node* src = e->src();
+ int src_output = e->src_output(), dst_input = e->dst_input();
+ g->RemoveEdge(e);
+
+ // Find or create placeholder node.
+ string new_name = absl::StrCat(src->name(), "_oc_to_oc_placeholder");
+ auto iter = placeholders.find(new_name);
+ Node* placeholder_node;
+ if (iter == placeholders.end()) {
+ NodeDefBuilder placeholder_builder(new_name, "Placeholder");
+ placeholder_builder.Attr("dtype", src->output_type(src_output));
+ string outside_compilation_attr;
+ TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
+ outside_compilation_attr_name,
+ &outside_compilation_attr));
+ placeholder_builder.Attr(outside_compilation_attr_name,
+ outside_compilation_attr);
+ placeholder_builder.Attr(kOutsideCompilationOriginalNodeAttrName,
+ src->name());
+ placeholder_builder.Attr(kOutsideCompilationSrcOutputAttrName,
+ src_output);
+ NodeDef placeholder_def;
+ TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
+ Status s;
+ placeholder_node = g->AddNode(placeholder_def, &s);
+ TF_RETURN_IF_ERROR(s);
+ placeholders[new_name] = placeholder_node;
+ } else {
+ placeholder_node = iter->second;
+ }
+ g->AddEdge(placeholder_node, 0, dst, dst_input);
+
+ // Replace `e->dst()` because its input node changed.
+ NodeDef new_def = dst->def();
+ *new_def.mutable_input(dst_input) = placeholder_node->name();
+ TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));
+
+ // Other edge in `edges` might have `e->dst()` as src or dst
+ // node. Before removing `e->dst()`, replace those edges with
+ // corresponding edges for `dst_replace_node`.
+ for (int j = i + 1; j < edges.size(); j++) {
+ if (edges[j].dst_node_id == edges[i].dst_node_id) {
+ edges[j].dst_node_id = dst_replace_node->id();
+ }
+ }
+ }
+ return Status::OK();
+}
+
+// Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PostprocessEdgesBetweenOutsideCompilations` for details.
+Status PostprocessDataEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ // Gather all outside compilation to outside compilation nodes.
+ std::vector<Node*> placeholder_nodes;
+ for (Node* n : g->nodes()) {
+ if (n->type_string() == "Placeholder" &&
+ HasNodeAttr(n->def(), kOutsideCompilationOriginalNodeAttrName)) {
+ placeholder_nodes.push_back(n);
+ }
+ }
+
+ // Remove the placeholder nodes, and reconnect original edge.
+ auto node_name_index = g->BuildNodeNameIndex();
+ for (auto n : placeholder_nodes) {
+ string node_name;
+ int node_src_output;
+ TF_RETURN_IF_ERROR(GetNodeAttr(
+ n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name));
+ TF_RETURN_IF_ERROR(GetNodeAttr(
+ n->attrs(), kOutsideCompilationSrcOutputAttrName, &node_src_output));
+ auto iter = node_name_index.find(node_name);
+ if (iter == node_name_index.end()) {
+ return errors::Internal(
+ "Cannot find original node for oc -> host placeholder node ",
+ node_name);
+ }
+
+ // Change all usage node to use the original node instead.
+ Node* original_node = iter->second;
+ std::vector<const Edge*> control_edges;
+ std::vector<OutEdgeInfo> data_edges;
+ for (auto e : n->out_edges()) {
+ if (e->IsControlEdge()) {
+ control_edges.push_back(e);
+ } else {
+ data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
+ }
+ }
+ for (const Edge* e : control_edges) {
+ g->AddControlEdge(original_node, e->dst());
+ g->RemoveEdge(e);
+ }
+ for (int i = 0; i < data_edges.size(); i++) {
+ Node* dst = data_edges[i].dst;
+ NodeDef new_def = dst->def();
+ int dst_input = data_edges[i].dst_input;
+ *new_def.mutable_input(dst_input) =
+ absl::StrCat(original_node->name(), ":", node_src_output);
+ TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));
+
+ const Edge* edge_to_replace = nullptr;
+ TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
+ g->RemoveEdge(edge_to_replace);
+ g->AddEdge(original_node, node_src_output, replace_node, dst_input);
+
+ // Other edges might have `dst` as dst node. Update those edges with
+ // `replace_node`.
+ for (int j = i + 1; j < data_edges.size(); j++) {
+ if (data_edges[j].dst == dst) {
+ data_edges[j].dst = replace_node;
+ }
+ }
+
+ // Other placeholder node might have `dst` as original node. Update
+ // `node_name_index` with `replace_node`.
+ node_name_index[replace_node->name()] = replace_node;
+ }
+
+ // Remove placeholder node.
+ g->RemoveNode(n);
+ }
+ return Status::OK();
+}
+
+// Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
+// `PostprocessEdgesBetweenOutsideCompilations` for details.
+Status PostprocessControlEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ auto node_name_index = g->BuildNodeNameIndex();
+
+ // Reconnect outside compilation to outside compilation control edge.
+ for (Node* n : g->nodes()) {
+ std::vector<string> control_deps;
+ Status s =
+ GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName,
+ &control_deps);
+ if (!s.ok()) {
+ if (s.code() != error::NOT_FOUND) {
+ return s;
+ } else {
+ continue;
+ }
+ } else {
+ n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName);
+ for (const string& control_input : control_deps) {
+ auto iter = node_name_index.find(control_input);
+ if (iter == node_name_index.end()) {
+ return errors::Internal("Cannot find original node for ",
+ control_input);
+ }
+ g->AddControlEdge(iter->second, n);
+ }
+ }
+ }
+ return Status::OK();
+}
} // namespace
const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";
-const char kXlaConnectedToXlaComputationAttrName[] =
- "_xla_connected_to_xla_computation";
-const char kXlaConnectedFromXlaComputationAttrName[] =
- "_xla_connected_from_xla_computation";
const char kXlaConnectedToOtherXlaComputationAttrName[] =
"_xla_connected_to_other_xla_computation";
const char kXlaConnectedFromOtherXlaComputationAttrName[] =
@@ -616,6 +822,15 @@
"_xla_host_to_oc_node_name";
const char kHostToOutsideCompilationSrcOutputAttrName[] =
"_xla_host_to_oc_src_output";
+const char kXlaConnectedToXlaComputationAttrName[] =
+ "_xla_connected_to_xla_computation";
+const char kXlaConnectedFromXlaComputationAttrName[] =
+ "_xla_connected_from_xla_computation";
+const char kOutsideCompilationOriginalNodeAttrName[] =
+ "_xla_oc_to_oc_node_name";
+const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output";
+const char kXlaControlDependenciesWithinXlaClusterAttrName[] =
+ "_xla_control_dependencies_within_xla_cluster";
Status PerformStaticShapeInferenceBeforeEncapsulation(
Graph* g, const string& xla_computation_attr_name,
@@ -699,4 +914,39 @@
return Status::OK();
}
+Status PreprocessEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ // Remove edges from source node to outside compilation nodes, and edges
+ // from outside compilation nodes to sink node.
+ std::vector<const Edge*> edges_to_remove;
+ for (const Edge* e : g->source_node()->out_edges()) {
+ if (HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
+ edges_to_remove.push_back(e);
+ }
+ }
+ for (const Edge* e : g->sink_node()->in_edges()) {
+ if (HasNodeAttr(e->src()->def(), outside_compilation_attr_name)) {
+ edges_to_remove.push_back(e);
+ }
+ }
+ for (auto e : edges_to_remove) {
+ g->RemoveEdge(e);
+ }
+
+ TF_RETURN_IF_ERROR(PreprocessControlEdgesBetweenOutsideCompilations(
+ g, outside_compilation_attr_name));
+ TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations(
+ g, outside_compilation_attr_name));
+ return Status::OK();
+}
+
+Status PostprocessEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name) {
+ TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations(
+ g, outside_compilation_attr_name));
+ TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations(
+ g, outside_compilation_attr_name));
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h
index 5e0c4bf..e363bc5 100644
--- a/tensorflow/compiler/jit/encapsulate_util.h
+++ b/tensorflow/compiler/jit/encapsulate_util.h
@@ -44,14 +44,6 @@
Graph* g, const string& xla_computation_attr_name,
const string& outside_compilation_attr_name);
-// Attribute indicating that some ops in this node's XLA computation has control
-// dependency on this node. Attribute value will always be "true".
-extern const char kXlaConnectedToXlaComputationAttrName[];
-
-// Attribute indicating that this node has control dependency on some ops in
-// this node's XLA computation. Attribute value will always be "true".
-extern const char kXlaConnectedFromXlaComputationAttrName[];
-
// Attribute indicating that some ops in other XLA computation has control
// dependency on this node. Attribute value will be a list of string (XLA
// computation names).
@@ -81,6 +73,14 @@
// int (src_output for original edge).
extern const char kOutsideCompilationToHostSrcOutputAttrName[];
+// Attribute indicating that some ops in this node's XLA computation has control
+// dependency on this node. Attribute value will always be "true".
+extern const char kXlaConnectedToXlaComputationAttrName[];
+
+// Attribute indicating that this node has control dependency on some ops in
+// this node's XLA computation. Attribute value will always be "true".
+extern const char kXlaConnectedFromXlaComputationAttrName[];
+
// Attribute indicating that this is an Placeholder node added to act as a
// temporary input node for an host node. Attribute value will be string
// (original input node name).
@@ -91,19 +91,31 @@
// for original edge).
extern const char kHostToOutsideCompilationSrcOutputAttrName[];
-// Preprocesses the graph for encapsulation. It will perform the following
-// operations in order:
+// Attribute indicating that this is an Placeholder node added to act as a
+// temporary input node for an outside compilation node. Attribute value will be
+// string (original input node name).
+extern const char kOutsideCompilationOriginalNodeAttrName[];
+
+// Attribute indicating that this is an Placeholder node added to act as a
+// temporary input node for an outside compilation node. Attribute value will be
+// int (src_output for original edge).
+extern const char kOutsideCompilationSrcOutputAttrName[];
+
+// Attribute indicating that this node has control dependencies on some other
+// nodes within the same XLA cluster. Attribute value will be a list of string
+// (node names).
+extern const char kXlaControlDependenciesWithinXlaClusterAttrName[];
+
+// Preprocesses edges between different XLA clusters for encapsulation. It will
+// perform the following operations in order:
//
-// 1a. For control edges between outside compilation and its XLA computation,
-// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the
-// outside compilation node.
-// 1b. For control edges between outside compilation and another XLA
+// 1a. For control edges between outside compilation and another XLA
// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName
// = XLA computation node name" to the outside compilation node.
-// 1c. For control edges between different outside compilations, remove the edge
-// and add attr "kXlaControlDependenciesAttrName = src node name" to dst
-// node.
-// 1d. For control edges between outside compilation and host computation,
+// 1b. For control edges between different outside compilations (in different
+// XLA computations), remove the edge and add attr
+// "kXlaControlDependenciesAttrName = src node name" to dst node.
+// 1c. For control edges between outside compilation and host computation,
// remove the edge and add attr "kXlaControlDependenciesAttrName = src node
// name" to dst node.
// 2. For data edges between different XLA computations, if either src or dst
@@ -146,26 +158,53 @@
const std::map<string, int> host_compute_core;
};
-// Postprocesses the graph for encapsulation. This function reverts what
-// `PreprocessForEncapsulation` did. It will perform the following operations in
-// order:
+// Postprocesses edges between different XLA clusters for encapsulation. This
+// function reverts what `PreprocessForEncapsulation` did. It will perform the
+// following operations in order:
//
// 1. Remove Placeholder nodes between outside compilation and host computation
// (created in `PreprocessForEncapsulation` step 3).
// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2.
-// 3a. Reconnect control edges between different outside compilations (marked by
-// `PreprocessForEncapsulation` step 1c) and control edges between outside
-// compilation and host computation (marked by `PreprocessForEncapsulation`
-// step 1d).
-// 3b. Reconnect control edges between outside compilation and another XLA
-// computation (marked by `PreprocessForEncapsulation` step 1b).
-// Notice that control edges marked by `PreprocessForEncapsulation` step 1a are
-// not handled here. They are handled in `RewriteOutsideCompilationSubgraphFn`.
+// 3a. Reconnect control edges between outside compilation and another XLA
+// computation (marked by `PreprocessForEncapsulation` step 1a).
+// 3b. Reconnect control edges between different outside compilations (marked by
+// `PreprocessForEncapsulation` step 1b).
+// 3c. Reconnect control edges between outside compilation and host computation
+// (marked by `PreprocessForEncapsulation` step 1c).
Status PostprocessForEncapsulation(
Graph* g, const string& xla_computation_attr_name,
const string& outside_compilation_attr_name,
const std::unordered_map<string, XlaClusterInfo>& clusters);
+// Preprocesses edges within the same XLA cluster. It will perform the following
+// operations in order:
+//
+// 0. Remove edges from source node to outside compilation nodes, and edges
+// from outside compilation nodes to sink node.
+// 1a. For edges between different outside compilation clusters, remove the edge
+// and add attr "kXlaControlDependenciesWithinXlaClusterAttrName = src node
+// name" to dst node.
+// 1b. For control edges between outside compilation and its XLA computation,
+// add attr "kXlaConnected{From, To}XlaComputationAttrName = true" to the
+// outside compilation node.
+// 2. For data edges between different outside compilations, remove the edge
+// and create a Placeholder node as dst node's input.
+Status PreprocessEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name);
+
+// Postprocesses edges within the same XLA cluster. This function reverts what
+// `PreprocessEdgesBetweenOutsideCompilations` did. It will perform the
+// following operations in order:
+//
+// 1. Remove Placeholder nodes between different outside compilations (created
+// in `PreprocessEdgesBetweenOutsideCompilations` step 2).
+// 2a. Reconnect control edges between different outside compilations (marked by
+// `PreprocessEdgesBetweenOutsideCompilations` step 1a).
+// Notice that control edges marked by
+// `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here.
+// They are handled in `RewriteOutsideCompilationSubgraphFn`.
+Status PostprocessEdgesBetweenOutsideCompilations(
+ Graph* g, const string& outside_compilation_attr_name);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_UTIL_H_
diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc
index 7255df3..25c32ce 100644
--- a/tensorflow/compiler/jit/encapsulate_util_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_util_test.cc
@@ -107,28 +107,19 @@
identity4_node->AddAttr("_xla", "1");
identity4_node->AddAttr("_oc", "0");
identity5_node->AddAttr("_xla", "1");
- // Case 1a: control edges between outside compilation and its XLA computation.
- g.AddControlEdge(add_node, identity0_node);
- g.AddControlEdge(identity0_node, identity1_node);
- // Case 1b: control edges between outside compilation and another XLA
+ // Case 1a: control edges between outside compilation and another XLA
// computation.
g.AddControlEdge(identity0_node, identity3_node);
g.AddControlEdge(identity1_node, identity4_node);
- // Case 1c: control edges between different outside compilations.
+ // Case 1b: control edges between different outside compilations.
g.AddControlEdge(identity0_node, identity4_node);
- // Case 1d: control edges between outside compilation and host computation.
+ // Case 1c: control edges between outside compilation and host computation.
g.AddControlEdge(const0_node, identity0_node);
g.AddControlEdge(identity0_node, identity2_node);
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
- // Case 1a: add attr "_xla_connected_{from/to}_xla_computation = true" to the
- // outside compilation node.
- EXPECT_TRUE(HasNodeAttr(identity0_node->def(),
- kXlaConnectedFromXlaComputationAttrName));
- EXPECT_TRUE(HasNodeAttr(identity0_node->def(),
- kXlaConnectedToXlaComputationAttrName));
- // Case 1b: add attr "_xla_control_deps_{from/to} = XLA computation node name"
+ // Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name"
// to the outside compilation node.
std::vector<string> attr;
TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
@@ -140,13 +131,13 @@
kXlaConnectedFromOtherXlaComputationAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "0");
- // Case 1c: add attr "_xla_control_deps = src node name" to dst node.
+ // Case 1b: add attr "_xla_control_deps = src node name" to dst node.
attr.clear();
TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
kXlaControlDependenciesAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "identity0");
- // Case 1d: add attr "_xla_control_deps = src node name" to dst node.
+ // Case 1c: add attr "_xla_control_deps = src node name" to dst node.
attr.clear();
TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
kXlaControlDependenciesAttrName, &attr));
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index 2ce6fa7..d334100 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -195,8 +195,11 @@
e->dst()->attrs().Find(kXlaClusterAttr) == nullptr &&
e->dst()->type_string() != kXlaClusterOutput) {
return errors::InvalidArgument(
- "Undeclared output of XLA computation. A common cause of this error "
- "is variable initializers that depend on the XLA computation. Edge: ",
+ "Undeclared output of XLA computation. Some common causes of this "
+ "error are: 1) variable initializers that depend on the XLA "
+ "computation; 2) gradient computations that depend on the XLA "
+ "computation, which can be mitigated by moving gradient computations "
+ "inside XLA computation. Offending edge: ",
e->src()->name(), ":", e->src_output(), " -> ", e->dst()->name(), ":",
e->dst_input());
}
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
index 8b3587c..e3c7e2f 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
@@ -366,7 +366,7 @@
// replace this node with compilation result node.
// 3) all outside compilation graphs.
Status ConstructHostGraph(
- const string& xla_cluster_name,
+ const string& xla_cluster_name, const string& outside_compilation_attr_name,
const std::vector<string>& outside_compilation_host_graphs,
FunctionLibraryDefinition* fld, std::unique_ptr<Graph>* host_graph) {
host_graph->reset(new Graph(fld));
@@ -476,6 +476,10 @@
host_graph->get(),
std::unordered_set<const Node*>{(*host_graph)->sink_node()});
+ // Postprocess edges between different outside compilations.
+ TF_RETURN_IF_ERROR(PostprocessEdgesBetweenOutsideCompilations(
+ host_graph->get(), outside_compilation_attr_name));
+
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("extract_outside_compilation_host_graph_for_",
@@ -801,6 +805,11 @@
},
&fbody));
std::unique_ptr<FunctionBody> fbody_deleter(fbody);
+
+ // Preprocess edges between different outside compilations. They will be
+ // restored in `ConstructHostGraph()`.
+ TF_RETURN_IF_ERROR(PreprocessEdgesBetweenOutsideCompilations(
+ fbody->graph, outside_compilation_attr_name));
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
absl::StrCat("extract_outside_compilation_for_func_before_", func_name),
@@ -860,8 +869,9 @@
// Construct host graph.
if (!outside_compilation_host_graphs.empty()) {
- TF_RETURN_IF_ERROR(ConstructHostGraph(
- xla_cluster_name, outside_compilation_host_graphs, fld, host_graph));
+ TF_RETURN_IF_ERROR(
+ ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name,
+ outside_compilation_host_graphs, fld, host_graph));
}
// Remove the outside compilation graphs from function library.
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
index c5bd64f..bff9561 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
@@ -290,21 +290,18 @@
TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shapes", &shapes));
EXPECT_EQ(shapes.size(), 1);
EXPECT_EQ(shapes[0].dim_size(), 1);
- // Check XlaHostCompute nodes' "shape_inference_graph" attr. "0" should have a
- // non-empty value, and "1" should have an empty value.
+ // Check XlaHostCompute nodes' "shape_inference_graph" attr. Both should have
+ // empty values.
string shape_inference_graph;
TF_CHECK_OK(GetNodeAttr(host_compute_0->attrs(), "shape_inference_graph",
&shape_inference_graph));
- EXPECT_EQ(shape_inference_graph,
- "_outside_compilation_shape_inference_cluster_0");
+ EXPECT_EQ(shape_inference_graph, "");
TF_CHECK_OK(GetNodeAttr(host_compute_1->attrs(), "shape_inference_graph",
&shape_inference_graph));
EXPECT_EQ(shape_inference_graph, "");
// Check `shape_inference_graphs`.
- EXPECT_EQ(shape_inference_graphs.size(), 1);
- EXPECT_EQ(shape_inference_graphs[0],
- "_outside_compilation_shape_inference_cluster_0");
+ EXPECT_EQ(shape_inference_graphs.size(), 0);
// Check `host_graph`: verify we have key placeholder and sequencer.
Node *key_placeholder = nullptr, *sequencer = nullptr;
@@ -333,8 +330,8 @@
send_recv_nodes.push_back(n);
}
}
- EXPECT_EQ(num_send_from_host, 2);
- EXPECT_EQ(num_recv_at_host, 2);
+ EXPECT_EQ(num_send_from_host, 1);
+ EXPECT_EQ(num_recv_at_host, 1);
for (Node *n : send_recv_nodes) {
Node *input_node;
TF_CHECK_OK(n->input_node(n->num_inputs() - 1, &input_node));
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 60b962d..2579643 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -72,6 +72,11 @@
// to resort to a dummy implementation. Currently Assert and CheckNumerics ops
// have dummy XLA implementations.
bool allow_dummy_ops;
+
+ // Whether ops that produce or consume DT_VARIANT values are allowed. We
+ // don't auto-cluster these ops because we don't yet support live-in or
+ // live-out DT_VARIANT values.
+ bool allow_ops_producing_or_consuming_variant;
};
bool IsDummyImplOp(absl::string_view op_name) {
@@ -84,6 +89,12 @@
op_name == "TruncatedNormal";
}
+bool OpProducesOrConsumesVariant(const Node& node) {
+ auto is_variant = [](DataType dtype) { return dtype == DT_VARIANT; };
+ return absl::c_any_of(node.input_types(), is_variant) ||
+ absl::c_any_of(node.output_types(), is_variant);
+}
+
bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
// There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
// is really a kind of function call and will be handled by
@@ -246,6 +257,10 @@
if (!op_filter.allow_dummy_ops && IsDummyImplOp(node->type_string())) {
return false;
}
+ if (!op_filter.allow_ops_producing_or_consuming_variant &&
+ OpProducesOrConsumesVariant(*node)) {
+ return false;
+ }
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, op_filter, depth + 1,
lib_runtime)) {
@@ -470,16 +485,15 @@
XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration));
DeviceType jit_device_type(registration->compilation_device_name);
+ bool always_auto_cluster = registration->autoclustering_policy ==
+ XlaOpRegistry::AutoclusteringPolicy::kAlways;
+
OperationFilter op_filter;
op_filter.allow_resource_ops = registration->compile_resource_ops;
- op_filter.allow_stateful_rng_ops =
- (registration->autoclustering_policy ==
- XlaOpRegistry::AutoclusteringPolicy::kAlways);
- op_filter.allow_control_trigger =
- (registration->autoclustering_policy ==
- XlaOpRegistry::AutoclusteringPolicy::kAlways);
- op_filter.allow_dummy_ops = (registration->autoclustering_policy ==
- XlaOpRegistry::AutoclusteringPolicy::kAlways);
+ op_filter.allow_stateful_rng_ops = always_auto_cluster;
+ op_filter.allow_control_trigger = always_auto_cluster;
+ op_filter.allow_dummy_ops = always_auto_cluster;
+ op_filter.allow_ops_producing_or_consuming_variant = always_auto_cluster;
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, op_filter, 0,
@@ -503,6 +517,12 @@
<< node->type_string() << ")";
continue;
}
+ if (!op_filter.allow_ops_producing_or_consuming_variant &&
+ OpProducesOrConsumesVariant(*node)) {
+ VLOG(2) << "Rejecting " << node->name()
+ << ": produces or consumes DT_VARIANT";
+ continue;
+ }
if (!op_filter.allow_resource_ops &&
(HasResourceOutput(*node) || IsNonResourceVarResourceOp(*node))) {
@@ -639,6 +659,7 @@
op_filter.allow_stateful_rng_ops = true;
op_filter.allow_control_trigger = true;
op_filter.allow_dummy_ops = true;
+ op_filter.allow_ops_producing_or_consuming_variant = true;
return IsCompilableCall(ndef, jit_device_type, op_filter, 0, flr);
}
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 24d78c0..bf2c550 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -22,6 +22,7 @@
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
@@ -1147,5 +1148,80 @@
EXPECT_EQ(clusters["test/check"], "");
}
+TEST(XlaCompilationTest, DontAutoClusterOpsProducingVariant) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
+ Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);
+
+ Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
+ Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);
+
+ Output tensor_list_reserve = ops::TensorListReserve(
+ root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_EQ(clusters["test/tensor_list_reserve"], "");
+}
+
+TEST(XlaCompilationTest, DontAutoClusterOpsConsumingVariant) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output dummy_input =
+ ops::Placeholder(root.WithOpName("test/dummy_input"), DT_INT64);
+ Output variant_input =
+ ops::Placeholder(root.WithOpName("test/variant_input"), DT_VARIANT);
+
+ // Create one more node so that we don't avoid creating a cluster solely
+ // because it would be trivial.
+ Output dummy_cast =
+ ops::Cast(root.WithOpName("test/dummy_cast"), dummy_input, DT_INT32);
+
+ Output tensor_list_element_shape = ops::TensorListElementShape(
+ root.WithOpName("test/tensor_list_element_shape"), variant_input,
+ DT_INT32);
+
+ root.graph()->AddControlEdge(dummy_cast.node(),
+ tensor_list_element_shape.node());
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_EQ(clusters["test/tensor_list_element_shape"], "");
+}
+
+TEST(XlaCompilationTest, ClusterOpsProducingVariantIfOnXlaDevice) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ Output a = ops::Placeholder(root.WithOpName("test/a"), DT_INT64);
+ Output b = ops::Placeholder(root.WithOpName("test/b"), DT_INT64);
+
+ Output cast_a = ops::Cast(root.WithOpName("test/cast_a"), a, DT_INT32);
+ Output cast_b = ops::Cast(root.WithOpName("test/cast_b"), b, DT_INT32);
+
+ Output tensor_list_reserve = ops::TensorListReserve(
+ root.WithOpName("test/tensor_list_reserve"), cast_a, cast_b, DT_FLOAT);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ string xla_cpu_device = "/job:worker/replica:0/task:0/device:XLA_CPU:0";
+ for (Node* n : graph->nodes()) {
+ if (absl::StartsWith(n->name(), /*prefix=*/"test/")) {
+ n->set_assigned_device_name(xla_cpu_device);
+ }
+ }
+
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_NE(clusters["test/tensor_list_reserve"], "");
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
index d56d0f8..64a3301 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
@@ -34,15 +34,9 @@
//
// It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to
// make this more direct, but probably not worth it solely for this test.
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices));
- auto delete_devices = gtl::MakeCleanup([&] {
- for (Device* d : devices) {
- delete d;
- }
- });
-
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.session_options = session_options;
diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD
index f722245..64409d9 100644
--- a/tensorflow/compiler/jit/ops/BUILD
+++ b/tensorflow/compiler/jit/ops/BUILD
@@ -18,3 +18,9 @@
out = "xla_ops.py",
deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
)
+
+py_library(
+ name = "xla_ops_grad",
+ srcs = ["xla_ops_grad.py"],
+ deps = ["//tensorflow/python:framework_ops"],
+)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/compiler/jit/ops/xla_ops_grad.py
similarity index 61%
rename from tensorflow/contrib/estimator/python/estimator/dnn.py
rename to tensorflow/compiler/jit/ops/xla_ops_grad.py
index 10f657d..2d31d8d 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn.py
+++ b/tensorflow/compiler/jit/ops/xla_ops_grad.py
@@ -1,3 +1,4 @@
+"""Gradients for XLA ops."""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,21 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""dnn python module.
-
-Importing from tensorflow.python.estimator is unsupported
-and will soon break!
-"""
-# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow_estimator.contrib.estimator.python.estimator import dnn
+from tensorflow.python.framework import ops
-# Include attrs that start with single underscore.
-_HAS_DYNAMIC_ATTRIBUTES = True
-dnn.__all__ = [s for s in dir(dnn) if not s.startswith('__')]
-from tensorflow_estimator.contrib.estimator.python.estimator.dnn import *
+@ops.RegisterGradient("XlaClusterOutput")
+def _XlaClusterOutputGrad(_, grad):
+ del grad # unused
+ raise RuntimeError("Gradient computation of graph in xla.compile() is "
+ "prohibited because it can cause performance degradation."
+ "Please move gradient computation inside xla.compile().")
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
index 1fc5da5..38a54cc 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -386,7 +386,7 @@
TF_ASSERT_OK(s.ToGraph(graph.get()));
// This is needed to register the XLA_GPU device.
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
TF_ASSERT_OK(DeviceFactory::AddDevices(
SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
@@ -400,10 +400,6 @@
TF_ASSERT_OK(PartiallyDecluster(&graph));
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
-
- for (Device* d : devices) {
- delete d;
- }
}
TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) {
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 9006dd5..7df898a 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -31,12 +31,12 @@
class XlaCpuDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector<Device*>* devices) override;
+ std::vector<std::unique_ptr<Device>>* devices) override;
};
-Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
- const string& name_prefix,
- std::vector<Device*>* devices) {
+Status XlaCpuDeviceFactory::CreateDevices(
+ const SessionOptions& session_options, const string& name_prefix,
+ std::vector<std::unique_ptr<Device>>* devices) {
XlaDeviceFlags* flags = GetXlaDeviceFlags();
bool compile_on_demand = flags->tf_xla_compile_on_demand;
@@ -63,8 +63,7 @@
options.device_ordinal = 0;
options.compilation_device_name = DEVICE_CPU_XLA_JIT;
options.use_multiple_streams = false;
- auto device = absl::make_unique<XlaDevice>(session_options, options);
- devices->push_back(device.release());
+ devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 4419701..944f732 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -29,12 +29,12 @@
class XlaGpuDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector<Device*>* devices) override;
+ std::vector<std::unique_ptr<Device>>* devices) override;
};
-Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& session_options,
- const string& name_prefix,
- std::vector<Device*>* devices) {
+Status XlaGpuDeviceFactory::CreateDevices(
+ const SessionOptions& session_options, const string& name_prefix,
+ std::vector<std::unique_ptr<Device>>* devices) {
XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
registration.autoclustering_policy =
@@ -70,7 +70,7 @@
return status;
}
- devices->push_back(device.release());
+ devices->push_back(std::move(device));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc
index e828bae..4007309 100644
--- a/tensorflow/compiler/jit/xla_interpreter_device.cc
+++ b/tensorflow/compiler/jit/xla_interpreter_device.cc
@@ -33,12 +33,12 @@
class XlaInterpreterDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector<Device*>* devices) override;
+ std::vector<std::unique_ptr<Device>>* devices) override;
};
Status XlaInterpreterDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
- std::vector<Device*>* devices) {
+ std::vector<std::unique_ptr<Device>>* devices) {
static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(
DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT);
(void)registrations;
@@ -61,8 +61,7 @@
options.device_ordinal = 0;
options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
options.use_multiple_streams = false;
- auto device = absl::make_unique<XlaDevice>(session_options, options);
- devices->push_back(device.release());
+ devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
return Status::OK();
}
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 2b88a64..bc3d60b 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -376,27 +376,6 @@
)
tf_xla_py_test(
- name = "resampler_ops_test",
- size = "small",
- srcs = ["resampler_ops_test.py"],
- disabled_backends = [
- # TODO(b/74459949) Support BatchDot in CPU backend.
- "cpu",
- "cpu_ondemand",
- ],
- # TODO(b/112295522): figure out how to make OSS build pass.
- tags = ["no_oss"],
- deps = [
- ":xla_test",
- "//tensorflow/contrib/resampler:resampler_ops",
- "//tensorflow/contrib/resampler:resampler_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:platform_test",
- ],
-)
-
-tf_xla_py_test(
name = "dynamic_stitch_test",
size = "small",
srcs = ["dynamic_stitch_test.py"],
diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py
index 532e2b5..5d5e486 100644
--- a/tensorflow/compiler/tests/categorical_op_test.py
+++ b/tensorflow/compiler/tests/categorical_op_test.py
@@ -27,6 +27,7 @@
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.platform import googletest
@@ -56,7 +57,7 @@
Returns:
Frequencies from sampled classes; shape [batch_size, num_classes].
"""
- with self.cached_session() as sess, self.test_scope():
+ with self.cached_session(), self.test_scope():
random_seed.set_random_seed(1618)
op = random_ops.multinomial(logits, num_samples,
output_dtype=dtypes.int32)
@@ -79,7 +80,7 @@
def _testRngIsNotConstant(self, rng, dtype, output_dtype):
# Tests that 'rng' does not always return the same value.
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
x = rng(dtype, output_dtype)
@@ -107,7 +108,7 @@
def testCategoricalIsInRange(self):
for dtype in self.float_types:
for output_dtype in self.output_dtypes():
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
x = random_ops.multinomial(
array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
@@ -138,6 +139,57 @@
chi2 = self._chi2(probs, freqs)
self.assertLess(chi2, 1e-3)
+ def testStatelessMultinomialIsInRange(self):
+ for dtype in self.float_types:
+ for output_dtype in self.output_dtypes():
+ with self.cached_session() as sess:
+ with self.test_scope():
+ seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+ x = stateless_random_ops.stateless_multinomial(
+ array_ops.ones(shape=[1, 20], dtype=dtype),
+ 1000,
+ seed_t,
+ output_dtype=output_dtype)
+ y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
+ self.assertTrue((y >= 0).sum() == 1000)
+ self.assertTrue((y < 20).sum() == 1000)
+
+ def testDeterminismMultinomial(self):
+ # Stateless values should be equal iff the seeds are equal (roughly)
+ num_samples = 10
+ with self.cached_session(), self.test_scope():
+ seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+ seeds = [(x, y) for x in range(5) for y in range(5)] * 3
+ for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
+ [0.25, 0.75]]):
+ pure = stateless_random_ops.stateless_multinomial(
+ logits, num_samples, seed=seed_t)
+ values = [(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds]
+ for s0, v0 in values:
+ for s1, v1 in values:
+ self.assertEqual(s0 == s1, np.all(v0 == v1))
+
+ def testEmpty(self):
+ with self.cached_session():
+ with self.test_scope():
+ x = random_ops.multinomial(
+ array_ops.zeros([42, 40]), 0, output_dtype=dtypes.int32)
+ y = self.evaluate(x)
+ self.assertEqual(y.shape, (42, 0))
+
+ def testEmptyStateless(self):
+ with self.cached_session() as sess:
+ with self.test_scope():
+ seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
+ x = stateless_random_ops.stateless_multinomial(
+ array_ops.zeros([42, 40]),
+ 0,
+ seed=seed_t,
+ output_dtype=dtypes.int32)
+ y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
+ self.assertEqual(y.shape, (42, 0))
+
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
index deb9ac1..2187f57 100644
--- a/tensorflow/compiler/tests/concat_ops_test.py
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -254,7 +254,7 @@
def DISABLED_testZeroSize(self):
# Verify that concat doesn't crash and burn for zero size inputs
np.random.seed(7)
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
for shape0 in (), (2,):
axis = len(shape0)
@@ -270,7 +270,7 @@
self.assertAllEqual(c.eval(), correct)
# Check gradients
dc = np.random.randn(*c.get_shape().as_list())
- dxs = sess.run(gradients_impl.gradients(c, xs, dc))
+ dxs = self.evaluate(gradients_impl.gradients(c, xs, dc))
self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))
def testConcatTuple(self):
@@ -330,7 +330,7 @@
class ConcatOffsetTest(xla_test.XLATestCase):
def testBasic(self):
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
@@ -344,7 +344,7 @@
class PackTest(xla_test.XLATestCase):
def testBasic(self):
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
@@ -354,7 +354,7 @@
self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
def testScalars(self):
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
s0 = constant_op.constant(2, dtypes.int32)
s1 = constant_op.constant(3, dtypes.int32)
@@ -364,7 +364,7 @@
self.assertAllEqual(ans, [2, 3, 5])
def testEmpty(self):
- with self.cached_session() as sess:
+ with self.cached_session():
with self.test_scope():
s0 = constant_op.constant([[]], dtypes.int32)
s1 = constant_op.constant([[]], dtypes.int32)
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index d1b90f0..bf5ea7b 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -42,7 +42,7 @@
def InLabels(labels, substr):
"""Returns true iff one of the labels contains substr."""
- return any([substr in x for x in labels])
+ return any(substr in x for x in labels)
class DenseLayerTest(test.TestCase):
@@ -72,7 +72,7 @@
x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32)
y = layers.dense(x, 3)
- sess.run(variables.initialize_all_variables())
+ self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup(
sess,
@@ -97,7 +97,7 @@
with jit_scope():
y = layers.dense(x, 3)
- sess.run(variables.initialize_all_variables())
+ self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup(
sess,
@@ -126,7 +126,7 @@
with jit_scope():
y = layers.dense(x, 3)
- sess.run(variables.initialize_all_variables())
+ self.evaluate(variables.initialize_all_variables())
run_metadata = config_pb2.RunMetadata()
test_utils.RunWithWarmup(
sess,
diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py
index 50b04da..e89cf97 100644
--- a/tensorflow/compiler/tests/dynamic_stitch_test.py
+++ b/tensorflow/compiler/tests/dynamic_stitch_test.py
@@ -58,6 +58,15 @@
[idx1, idx2], [val1, val2],
expected=np.array([[], [], [], []], np.int32))
+ def testEmptyIndex(self):
+ idx1 = np.array([], dtype=np.int32)
+ idx2 = np.array([[], []], dtype=np.int32)
+ val1 = np.ndarray(shape=(0, 9), dtype=np.int32)
+ val2 = np.ndarray(shape=(2, 0, 9), dtype=np.int32)
+ self._AssertDynamicStitchResultIs([idx1, idx2], [val1, val2],
+ expected=np.ndarray(
+ shape=(0, 9), dtype=np.int32))
+
def testSimple1D(self):
val1 = np.array([0, 4, 7], dtype=np.int32)
val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32)
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 76706ad..2af32b5 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -101,7 +101,7 @@
self.assertAllEqual(15, product)
# Run some ops graphly
- with context.graph_mode(), self.cached_session() as sess:
+ with context.graph_mode(), self.cached_session():
with self.test_scope():
three = constant_op.constant(3)
five = constant_op.constant(5)
diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py
index 61abf9c..0edd0c3 100644
--- a/tensorflow/compiler/tests/fft_test.py
+++ b/tensorflow/compiler/tests/fft_test.py
@@ -158,6 +158,23 @@
self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
+ def testRFFT3DMismatchedSize(self):
+
+ def _to_expected(x):
+ return np.fft.rfftn(
+ x,
+ axes=(-3, -2, -1),
+ s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
+
+ def _tf_fn(x):
+ return signal.rfft3d(
+ x,
+ fft_length=[
+ x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
+ ])
+
+ self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn)
+
def testIRFFT(self):
def _tf_fn(x):
@@ -202,6 +219,30 @@
self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
+ def testIRFFT3DMismatchedSize(self):
+
+ def _to_input(x):
+ return np.fft.rfftn(
+ np.real(x),
+ axes=(-3, -2, -1),
+ s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
+
+ def _to_expected(x):
+ return np.fft.irfftn(
+ x,
+ axes=(-3, -2, -1),
+ s=[x.shape[-3] // 2, x.shape[-2], x.shape[-1] * 2])
+
+ def _tf_fn(x):
+ return signal.irfft3d(
+ x,
+ fft_length=[
+ x.shape[-3].value // 2, x.shape[-2].value, x.shape[-1].value * 2
+ ])
+
+ self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn)
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py
index dd9b7f3..a61827c 100644
--- a/tensorflow/compiler/tests/function_test.py
+++ b/tensorflow/compiler/tests/function_test.py
@@ -40,7 +40,7 @@
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
expected = APlus2B(aval, bval)
- with self.cached_session() as sess:
+ with self.cached_session():
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@@ -66,7 +66,7 @@
bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
expected = APlus2B(aval, bval)
- with self.cached_session() as sess:
+ with self.cached_session():
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
@@ -90,7 +90,7 @@
bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
expected = Func(aval, bval)
- with self.cached_session() as sess:
+ with self.cached_session():
@function.Defun(dtypes.float32, dtypes.float32)
def Foo(a, b):
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 6f51ae3..dbea984 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -75,7 +75,7 @@
def InLabels(labels, substr):
"""Returns true iff one of the labels contains substr."""
- return any([substr in x for x in labels])
+ return any(substr in x for x in labels)
def MetadataHasXlaRunOp(run_metadata):
diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py
index 5862211..0210201 100644
--- a/tensorflow/compiler/tests/listdiff_op_test.py
+++ b/tensorflow/compiler/tests/listdiff_op_test.py
@@ -33,13 +33,13 @@
def _testListDiff(self, x, y, out, idx):
for dtype in [dtypes.int32, dtypes.int64]:
for index_dtype in [dtypes.int32, dtypes.int64]:
- with self.cached_session() as sess:
+ with self.cached_session():
x_tensor = ops.convert_to_tensor(x, dtype=dtype)
y_tensor = ops.convert_to_tensor(y, dtype=dtype)
with self.test_scope():
out_tensor, idx_tensor = array_ops.listdiff(
x_tensor, y_tensor, out_idx=index_dtype)
- tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+ tf_out, tf_idx = self.evaluate([out_tensor, idx_tensor])
self.assertAllEqual(out, tf_out)
self.assertAllEqual(idx, tf_idx)
self.assertEqual(1, out_tensor.get_shape().ndims)
diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py
index fd02a50..776ed89 100644
--- a/tensorflow/compiler/tests/lstm_test.py
+++ b/tensorflow/compiler/tests/lstm_test.py
@@ -89,7 +89,7 @@
# Initialize variables and run the unrolled LSTM step.
self.evaluate(variables.global_variables_initializer())
- return sess.run([m, c])
+ return self.evaluate([m, c])
def testLSTMCell(self):
# Run with all-0 weights, no padding.
@@ -174,7 +174,7 @@
# Initialize variables and run the unrolled LSTM layer.
self.evaluate(variables.global_variables_initializer())
- return sess.run(out_seq)
+ return self.evaluate(out_seq)
def testLSTMLayer(self):
# Run with all-0 weights, no padding.
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 1e91390..97ffad3 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -111,7 +111,7 @@
return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
def probit(x, sess=sess):
- return sess.run(special_math.ndtri(x))
+ return self.evaluate(special_math.ndtri(x))
a = -2.
b = 2.
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index a6b5802..d23fd12 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -3382,10 +3382,10 @@
}
// XLA devices register kernels at construction time; create all known devices
// to make sure the kernels are registered.
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
tensorflow::SessionOptions(), "", &devices));
- tensorflow::DeviceMgr device_mgr(devices);
+ tensorflow::DeviceMgr device_mgr(std::move(devices));
tensorflow::Device* ignored;
TF_QCHECK_OK(
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
index 132c59c..e8fc81b 100644
--- a/tensorflow/compiler/tests/reduce_ops_test.py
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -91,6 +91,7 @@
np.array([], dtype=np.bool).reshape(0, 3),
np.array([[False, True, False], [True, True, False]]),
]
+ ONES = [np.ones([34000, 2])]
def testReduceSumF32(self, index_dtype):
self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.REAL_DATA,
@@ -149,6 +150,11 @@
self._testReduction(math_ops.reduce_mean, np.mean, np.float32,
self.NONEMPTY_REAL_DATA, index_dtype)
+ def testReduceMeanF16(self, index_dtype):
+ if np.float16 in self.all_types:
+ self._testReduction(math_ops.reduce_mean, np.mean, np.float16, self.ONES,
+ index_dtype)
+
def testReduceMeanC64(self, index_dtype):
self._testReduction(math_ops.reduce_mean, np.mean, np.complex64,
self.NONEMPTY_COMPLEX_DATA, index_dtype)
diff --git a/tensorflow/compiler/tests/rmsprop_test.py b/tensorflow/compiler/tests/rmsprop_test.py
index 5138a4a..dc3e90b 100644
--- a/tensorflow/compiler/tests/rmsprop_test.py
+++ b/tensorflow/compiler/tests/rmsprop_test.py
@@ -76,7 +76,7 @@
rms_opt = rmsprop.RMSPropOptimizer(learning_rate, centered=centered)
rms_update = rms_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
mg0 = rms_opt.get_slot(var0, "mg")
self.assertEqual(mg0 is not None, centered)
@@ -97,7 +97,7 @@
# Run 3 steps of RMSProp
for _ in range(3):
- rms_update.run()
+ self.evaluate(rms_update)
var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
var0_np,
diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py
index 897db38..17639bd 100644
--- a/tensorflow/compiler/tests/scan_ops_test.py
+++ b/tensorflow/compiler/tests/scan_ops_test.py
@@ -71,7 +71,7 @@
class CumsumTest(xla_test.XLATestCase):
- valid_dtypes = [np.float32]
+ valid_dtypes = [np.float32, np.int32]
def axis_dtypes(self):
return set(self.int_types).intersection([np.int32, np.int64])
@@ -149,7 +149,7 @@
class CumprodTest(xla_test.XLATestCase):
- valid_dtypes = [np.float32]
+ valid_dtypes = [np.float32, np.int32]
def axis_dtypes(self):
return set(self.int_types).intersection([np.int32, np.int64])
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index 21708aa..ee7ca7e 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -156,7 +156,7 @@
return math.exp(-(x**2) / 2.) / math.sqrt(2 * math.pi)
def probit(x, sess=sess):
- return sess.run(special_math.ndtri(x))
+ return self.evaluate(special_math.ndtri(x))
a = -2.
b = 2.
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index d612d3b..95c9e7ff 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -481,6 +481,72 @@
np.array([-1, -0.5, 0, 0.3], dtype=dtype),
expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
+ def quantize_and_dequantize_v2_round_half_up(x):
+ return array_ops.quantize_and_dequantize_v2(
+ x,
+ -1,
+ 1.0,
+ signed_input=True,
+ num_bits=8,
+ range_given=True,
+ round_mode="HALF_UP")
+
+ self._assertOpOutputMatchesExpected(
+ quantize_and_dequantize_v2_round_half_up,
+ np.array([-0.8, -0.5, 0, 0.3, 0.8, -2, 33], dtype=dtype),
+ expected=np.array([
+ -102.0 / 127,
+ -63.0 / 127,
+ 0,
+ 38.0 / 127,
+ 102.0 / 127,
+ -128.0 / 127,
+ 1,
+ ],
+ dtype=dtype))
+
+ def quantize_and_dequantize_v2_round_half_to_even(x):
+ return array_ops.quantize_and_dequantize_v2(
+ x,
+ -1.0,
+ 1.0,
+ signed_input=True,
+ num_bits=8,
+ range_given=True,
+ round_mode="HALF_TO_EVEN")
+
+ self._assertOpOutputMatchesExpected(
+ quantize_and_dequantize_v2_round_half_to_even,
+ np.array(
+ [
+ -0.8,
+ # The -0.5 should become -63.5 after scaling and with
+ # rounding this should become -64. But with the test
+ # unary_ops_test_cpu_ondemand, this fails as the result
+ # before scaling becomes -63.499996 and gets rounded to -63.
+ # TODO(sreenik): Some one more familiar with this test needs
+ # to take a look and resolve this. This works on all other
+ # variations of the platform like cpu, and gpu.
+ # -0.5,
+ 0,
+ 0.3,
+ 0.8,
+ -2,
+ 33
+ ],
+ dtype=dtype),
+ expected=np.array(
+ [
+ -102.0 / 127,
+ # -64.0 / 127,
+ 0,
+ 38.0 / 127,
+ 102.0 / 127,
+ -128.0 / 127,
+ 1,
+ ],
+ dtype=dtype))
+
def quantize_and_dequantize_v3(x):
return array_ops.quantize_and_dequantize_v3(
x, -127, 127, num_bits=8, signed_input=True, range_given=False)
diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py
index e776c8a..fcd7ac5 100644
--- a/tensorflow/compiler/tests/variable_ops_test.py
+++ b/tensorflow/compiler/tests/variable_ops_test.py
@@ -77,7 +77,7 @@
sess.run(variables.variables_initializer([v]))
x = v.sparse_read(2)
self.assertAllClose(
- np.array([8j, 9, 10, 11]).astype(dtype), sess.run(x))
+ np.array([8j, 9, 10, 11]).astype(dtype), self.evaluate(x))
def testSparseRead1DIndices(self):
for dtype in self.numeric_types:
@@ -89,7 +89,7 @@
x = v.sparse_read([2, 1])
self.assertAllClose(
np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype),
- sess.run(x))
+ self.evaluate(x))
def testSparseRead2DIndices(self):
for dtype in self.numeric_types:
@@ -102,7 +102,7 @@
self.assertAllClose(
np.array([[[8, 9, 10, 11], [4, 5, 6, 7]],
[[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype),
- sess.run(x))
+ self.evaluate(x))
def testSparseRead2DIndices3DTensor(self):
for dtype in self.numeric_types:
@@ -115,9 +115,9 @@
x = v.sparse_read([[2, 1], [3, 0]])
self.assertAllClose(
np.array(
- [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]
- ], [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]
- ],).astype(dtype), sess.run(x))
+ [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]],
+ [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]
+ ],).astype(dtype), self.evaluate(x))
def testShape(self):
for dtype in self.numeric_types:
diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py
index 28d61fb..ef55292 100644
--- a/tensorflow/compiler/tests/xla_device_test.py
+++ b/tensorflow/compiler/tests/xla_device_test.py
@@ -81,7 +81,7 @@
with self.cached_session() as sess:
with self.test_scope():
x = gen_control_flow_ops.control_trigger()
- sess.run(x)
+ self.evaluate(x)
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 486b4d8..25a84fb 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -9,6 +9,7 @@
"//tensorflow/compiler/jit/...",
"//tensorflow/compiler/tests/...",
"//tensorflow/compiler/tf2xla/...",
+ "//tensorflow/contrib/compiler/...",
],
)
@@ -211,7 +212,6 @@
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
- "//tensorflow/compiler/xla/client/lib:numeric",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
index 2db2514..795ea09 100644
--- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
@@ -50,7 +50,7 @@
return;
}
- const XlaExpression& arg = XlaContext::Get(ctx).args()[index_];
+ const XlaExpression& arg = ctx->xla_context()->args()[index_];
OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid,
errors::InvalidArgument("Invalid/missing argument expression"));
ctx->SetOutputExpression(0, arg);
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index a267c0c..0e2f335 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -115,9 +115,9 @@
// operators. For now, cast everything to the statistics type (which
// may be more precise than the input type).
auto grad_backprop =
- XlaHelpers::ConvertElementType(b, ctx->Input(0), scale_dtype);
+ XlaHelpers::ConvertElementType(ctx->Input(0), scale_dtype);
auto activations =
- XlaHelpers::ConvertElementType(b, ctx->Input(1), scale_dtype);
+ XlaHelpers::ConvertElementType(ctx->Input(1), scale_dtype);
auto scale = ctx->Input(2);
auto mean = ctx->Input(3);
auto var = ctx->Input(4);
@@ -151,11 +151,11 @@
const DataType accumulation_type =
XlaHelpers::SumAccumulationType(scale_dtype);
auto converted =
- XlaHelpers::ConvertElementType(b, grad_backprop, accumulation_type);
+ XlaHelpers::ConvertElementType(grad_backprop, accumulation_type);
auto reduce =
xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
- offset_backprop = XlaHelpers::ConvertElementType(b, reduce, scale_dtype);
+ offset_backprop = XlaHelpers::ConvertElementType(reduce, scale_dtype);
// scratch1 = rsqrt(pop_var + epsilon)
auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5);
@@ -165,19 +165,18 @@
// scratch2 = sum(y_backprop * (x - mean))
auto mul =
xla::Mul(grad_backprop, xla::Sub(activations, mean, {feature_index}));
- converted = XlaHelpers::ConvertElementType(b, mul, accumulation_type);
+ converted = XlaHelpers::ConvertElementType(mul, accumulation_type);
reduce =
xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), reduction_dims);
- auto scratch2 = XlaHelpers::ConvertElementType(b, reduce, scale_dtype);
+ auto scratch2 = XlaHelpers::ConvertElementType(reduce, scale_dtype);
x_backprop =
xla::Mul(grad_backprop, xla::Mul(scratch1, scale), {feature_index});
scale_backprop = xla::Mul(scratch1, scratch2);
}
- ctx->SetOutput(0,
- XlaHelpers::ConvertElementType(b, x_backprop, input_dtype));
+ ctx->SetOutput(0, XlaHelpers::ConvertElementType(x_backprop, input_dtype));
ctx->SetOutput(1, scale_backprop);
ctx->SetOutput(2, offset_backprop);
ctx->SetConstantOutput(3, Tensor());
diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
index 41f5405..e7f369b 100644
--- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
@@ -107,11 +107,11 @@
const DataType accumulation_type =
XlaHelpers::SumAccumulationType(input_type(0));
auto converted =
- XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type);
+ XlaHelpers::ConvertElementType(ctx->Input(0), accumulation_type);
auto reduce =
xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), reduce_dims);
- ctx->SetOutput(0, XlaHelpers::ConvertElementType(b, reduce, input_type(0)));
+ ctx->SetOutput(0, XlaHelpers::ConvertElementType(reduce, input_type(0)));
}
private:
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index ad85940..7199b9b 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -21,10 +21,13 @@
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
namespace {
@@ -57,11 +60,9 @@
const int64 batch_size = logits_shape.dim_size(0);
const int64 num_classes = logits_shape.dim_size(1);
- xla::XlaBuilder* builder = ctx->builder();
-
xla::Shape uniform_shape;
int class_dimension;
- if (num_samples > 1) {
+ if (num_samples != 1) {
std::array<int64, 3> uniform_shape_array = {
{batch_size, num_samples, num_classes}};
xla::PrimitiveType uniform_xla_type;
@@ -83,16 +84,16 @@
xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array);
class_dimension = 1;
}
- xla::XlaOp uniforms =
- xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)),
- XlaHelpers::One(builder, input_type(0)), uniform_shape);
+ xla::PrimitiveType type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &type));
+ xla::XlaOp log_uniforms = GetLogUniforms(uniform_shape, type, ctx);
// Use Gumbel softmax trick to generate categorical samples.
// See:
// https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
// TODO(b/68769470): Switch to using a cumulative sum approach.
auto softmax_entries =
- xla::Sub(logits, xla::Log(-xla::Log(uniforms)),
+ xla::Sub(logits, log_uniforms,
/*broadcast_dimensions=*/{0, class_dimension});
xla::PrimitiveType xla_output_type;
@@ -107,6 +108,16 @@
ctx->SetOutput(0, argmax);
}
+ virtual xla::XlaOp GetLogUniforms(xla::Shape uniform_shape,
+ xla::PrimitiveType type,
+ XlaOpKernelContext* ctx) {
+ xla::XlaBuilder* builder = ctx->builder();
+ auto uniforms =
+ xla::RngUniform(XlaHelpers::Zero(builder, input_type(0)),
+ XlaHelpers::One(builder, input_type(0)), uniform_shape);
+ return xla::Log(-xla::Log(uniforms));
+ }
+
private:
TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp);
};
@@ -115,5 +126,48 @@
REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstantInput("num_samples"),
CategoricalOp);
+class StatelessCategoricalOp : public CategoricalOp {
+ public:
+ explicit StatelessCategoricalOp(OpKernelConstruction* ctx)
+ : CategoricalOp(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ }
+
+ xla::XlaOp GetLogUniforms(xla::Shape uniform_shape, xla::PrimitiveType type,
+ XlaOpKernelContext* ctx) override {
+ xla::XlaOp seed = ctx->Input(2);
+ auto seed0 = xla::Reshape(xla::Slice(seed, {0}, {1}, {1}), {});
+ auto seed1 = xla::Reshape(xla::Slice(seed, {1}, {2}, {1}), {});
+
+ xla::XlaBuilder* builder = ctx->builder();
+ if (uniform_shape.element_type() == xla::BF16) {
+ uniform_shape.set_element_type(xla::F32);
+ }
+ auto uniforms = xla::StatelessRngUniform(
+ {seed0, seed1}, uniform_shape, XlaHelpers::Zero(builder, DT_FLOAT),
+ XlaHelpers::One(builder, DT_FLOAT));
+ return xla::ConvertElementType(xla::Log(-xla::Log(uniforms)), type);
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape seed_shape = ctx->InputShape(2);
+ OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2,
+ errors::InvalidArgument("seed must have shape [2], not ",
+ seed_shape.DebugString()));
+ CategoricalOp::Compile(ctx);
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(StatelessCategoricalOp);
+};
+
+REGISTER_XLA_OP(Name("StatelessMultinomial")
+ .CompileTimeConstantInput("num_samples")
+ .TypeConstraint("T", {DT_FLOAT, DT_BFLOAT16})
+ .TypeConstraint("Tseed", DT_INT32),
+ StatelessCategoricalOp);
+
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
index c9a1be4..641fefa 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -24,7 +24,6 @@
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -65,60 +64,63 @@
// 0 0 1 1 0 0 0 0 1 1 0 0
// 0 0 0 0 1 1 0 0 0 0 1 1
//
-// The first step is to create a one tensor, A, that is [3]
-// 0 1 2
+// The first step is to create a iota A with iota_dimension = 2
+// 0 0 0 0 0 0 0 0 0 0 0 0
+// 1 1 1 1 1 1 1 1 1 1 1 1
+// 2 2 2 2 2 2 2 2 2 2 2 2
//
-// and another tensor, B, that is [3 * 2]
-// 0 1 2 3 4 5
+// 0 0 0 0 0 0 0 0 0 0 0 0
+// 1 1 1 1 1 1 1 1 1 1 1 1
+// 2 2 2 2 2 2 2 2 2 2 2 2
//
-// and divide B it by 2 to get
-// 0 0 1 1 2 2
+// and another iota B with iota_dimension = 3
+// 0 1 2 3 4 5 0 1 2 3 4 5
+// 0 1 2 3 4 5 0 1 2 3 4 5
+// 0 1 2 3 4 5 0 1 2 3 4 5
//
-// then we broadcast the B to [2, 2, 3, 3 * 2]
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 1 2 3 4 5 0 1 2 3 4 5
+// 0 1 2 3 4 5 0 1 2 3 4 5
+// 0 1 2 3 4 5 0 1 2 3 4 5
//
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
-// 0 0 1 1 2 2 0 0 1 1 2 2
+// and divide B by 2 to get
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
//
-// Finally compare A and broadcasted B in dimension 2 amd return the result at
-// the beginning of the comment.
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// Finally compare A and B and return the result at the beginning of the
+// comment.
xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape,
xla::XlaBuilder* builder) {
xla::Shape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
int64 depthwise_multiplier =
filter_shape.dimensions(filter_shape.dimensions_size() - 1);
- int64 input_feature =
- filter_shape.dimensions(filter_shape.dimensions_size() - 2);
- // Create a M sized linspace and an M*N sized linspace that will be
- // broadcasted into perpendicular dimensions and compared.
- xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature);
- xla::XlaOp expanded_feature_iota =
- xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier);
+ // Create two iotas with the shape of the expanded filter, one of them with
+ // the iota dimension chosen as the feature dimension, and the other a iota
+ // with the iota dimension chosen as the expanded output feature dimension.
+ std::vector<int64> iota_dimensions(expanded_filter_shape.dimensions().begin(),
+ expanded_filter_shape.dimensions().end());
+ xla::Shape iota_shape = xla::ShapeUtil::MakeShape(xla::S32, iota_dimensions);
+ xla::XlaOp input_feature_iota = xla::Iota(
+ builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 2);
+ xla::XlaOp expanded_feature_iota = xla::Iota(
+ builder, iota_shape, /*iota_dimension=*/iota_dimensions.size() - 1);
- // Divide the M*N sized linspace by the depthwise_multiplier to create
- // [0 0 1 1 2 2] in the example in the function comment.
+ // Divide 'expanded_feature_iota' by the depthwise_multiplier to create
+ // [0 0 1 1 2 2] ... in the example in the function comment.
expanded_feature_iota =
xla::Div(expanded_feature_iota,
XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
depthwise_multiplier));
- // Broadcast the N*M linspace to [H, W, ..., M, M*N].
- std::vector<int64> expanded_feature_broadcast_dims(
- expanded_filter_shape.dimensions().begin(),
- expanded_filter_shape.dimensions().end());
- expanded_feature_broadcast_dims.pop_back();
- auto broadcasted_expanded_feature_iota =
- xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims);
-
- // Compare the broadcasted linspace to the input feature linspace in the
- // input feature dimension to create a diagonal predicate.
- return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota,
- {expanded_filter_shape.dimensions_size() - 2});
+ // Compare 'input_feature_iota' with 'expanded_feature_iota' to create a
+ // diagonal predicate.
+ return xla::Eq(expanded_feature_iota, input_feature_iota);
}
// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
index b2f6ef4..6e6ba21 100644
--- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -113,8 +113,20 @@
}
}
int number_of_indices = max_index + 1;
- OP_REQUIRES(ctx, number_of_indices > 0,
- errors::InvalidArgument("no indices supplied"));
+ int64 result_rank = 1 + data0_shape.dims() - indices0_shape.dims();
+ if (number_of_indices == 0) {
+ std::vector<int64> result_shape(result_rank);
+ for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) {
+ result_shape[d - indices0_shape.dims() + 1] = data0_shape.dim_size(d);
+ }
+ xla::PrimitiveType element_type =
+ ctx->input_xla_type(ctx->num_inputs() - 1);
+ xla::Literal empty_literal = xla::Literal::CreateFromShape(
+ xla::ShapeUtil::MakeShape(element_type, result_shape));
+ ctx->SetOutput(0, xla::ConstantLiteral(ctx->builder(), empty_literal));
+ return;
+ }
+
// Construct the reverse mapping, for each index, of which slice of which
// input it comes from.
std::vector<int32> src_input_vector(number_of_indices);
@@ -157,12 +169,9 @@
// Set up the vectors for slicing: the first dimension will vary
// slice by slice, and the rest take the full common extra shape.
- std::vector<int64> slice_start(1 + data0_shape.dims() -
- indices0_shape.dims());
- std::vector<int64> slice_limit(1 + data0_shape.dims() -
- indices0_shape.dims());
- std::vector<int64> stride(1 + data0_shape.dims() - indices0_shape.dims(),
- 1);
+ std::vector<int64> slice_start(result_rank);
+ std::vector<int64> slice_limit(result_rank);
+ std::vector<int64> stride(result_rank, 1);
for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) {
slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
index c68b0bf..29687c7 100644
--- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc
@@ -17,7 +17,6 @@
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
index cdba668..142be03 100644
--- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
@@ -260,19 +260,19 @@
xla::XlaOp below_min = xla::Lt(input, nudged_input_min);
xla::XlaOp select1 = xla::Select(below_min, gradient, zeroes);
xla::XlaOp reduce1 = xla::ReduceAll(
- XlaHelpers::ConvertElementType(b, select1, accumulation_type),
+ XlaHelpers::ConvertElementType(select1, accumulation_type),
XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type));
- xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type);
+ xla::XlaOp output1 = XlaHelpers::ConvertElementType(reduce1, data_type);
ctx->SetOutput(1, output1);
xla::XlaOp above_max = xla::Gt(input, nudged_input_max);
xla::XlaOp select2 = xla::Select(above_max, gradient, zeroes);
xla::XlaOp reduce2 = xla::ReduceAll(
- XlaHelpers::ConvertElementType(b, select2, accumulation_type),
+ XlaHelpers::ConvertElementType(select2, accumulation_type),
XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type));
- xla::XlaOp output2 = XlaHelpers::ConvertElementType(b, reduce2, data_type);
+ xla::XlaOp output2 = XlaHelpers::ConvertElementType(reduce2, data_type);
ctx->SetOutput(2, output2);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
index 9b06357..6df8b53 100644
--- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc
@@ -20,6 +20,7 @@
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -50,11 +51,36 @@
errors::InvalidArgument("input must be at least 1 dimensional"));
std::vector<int64> fft_length;
+ xla::XlaOp input = ctx->Input(0);
if (fft_type_ == FftType::RFFT || fft_type_ == FftType::IRFFT) {
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &fft_length));
OP_REQUIRES(ctx, fft_length.size() == fft_rank_,
errors::InvalidArgument("fft_length must be length ",
fft_rank_, " vector"));
+
+ // Zero pad or truncate the axes we're doing FFT on.
+ absl::InlinedVector<int64, 4> slice_sizes = input_shape.dim_sizes();
+ std::vector<std::pair<int64, int64>> padding_sizes(slice_sizes.size());
+ std::vector<int64> expected_sizes = fft_length;
+ // IRFFT wants the innermost axis to be n / 2 + 1.
+ if (fft_type_ == FftType::IRFFT) {
+ expected_sizes[fft_rank_ - 1] = fft_length[fft_rank_ - 1] / 2 + 1;
+ }
+ for (int i = 0; i < fft_rank_; i++) {
+ int index = input_shape.dims() - fft_rank_ + i;
+ if (input_shape.dim_size(index) > expected_sizes[i]) {
+ slice_sizes[index] = expected_sizes[i];
+ } else {
+ padding_sizes[index].second =
+ expected_sizes[i] - input_shape.dim_size(index);
+ }
+ }
+
+ std::vector<int64> start_indices(input_shape.dims(), 0);
+ std::vector<int64> strides(input_shape.dims(), 1);
+ input = xla::Pad(xla::Slice(input, start_indices, slice_sizes, strides),
+ XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)),
+ xla::MakeEdgePaddingConfig(padding_sizes));
} else {
// Innermost axis provides the FFT length.
for (int i = 0; i < fft_rank_; i++) {
@@ -63,7 +89,7 @@
}
}
- xla::XlaOp fft = xla::Fft(ctx->Input(0), fft_type_, fft_length);
+ xla::XlaOp fft = xla::Fft(input, fft_type_, fft_length);
ctx->SetOutput(0, fft);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc
index 56da50f..b5e0839 100644
--- a/tensorflow/compiler/tf2xla/kernels/if_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc
@@ -72,7 +72,7 @@
arg.shape = resource->shape();
OP_REQUIRES(ctx, arg.initialized,
errors::Unimplemented("Uninitialized arguments: ", arg.name));
- arg.tensor_array_size = resource->tensor_array_size();
+ arg.max_array_size = resource->max_array_size();
for (const auto& gradient : resource->tensor_array_gradients()) {
arg.tensor_array_gradients.insert(gradient.first);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
index b49b251..e9bb0a7 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc
@@ -191,12 +191,11 @@
DataType type = context->input_type(0);
const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
- auto converted =
- XlaHelpers::ConvertElementType(b, input, accumulation_type);
+ auto converted = XlaHelpers::ConvertElementType(input, accumulation_type);
auto reduce = xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
*context->GetOrCreateAdd(accumulation_type),
{height_dim, width_dim});
- auto output = XlaHelpers::ConvertElementType(b, reduce, type);
+ auto output = XlaHelpers::ConvertElementType(reduce, type);
output =
xla::Div(output, XlaHelpers::FloatLiteral(b, type, height * width));
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index e310db2..e2c05b6 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -30,7 +30,9 @@
namespace tensorflow {
namespace {
-// The logic below uses a custom-call to implement argmax.
+// The logic below uses a custom-call to implement argmax when possible. When
+// custom-call is not allowed or input shapes are not supported, this kernel
+// falls back to using XLA HLO native ArgMax.
//
// Also see b/29507024 for first-class XLA support for indexing ops.
class ArgMaxCustomCallOp : public XlaOpKernel {
@@ -50,27 +52,40 @@
// overhead, when compiling ahead-of-time.
int64 dim;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim));
- OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0"));
- OP_REQUIRES(
- ctx, dim < input_shape.dims(),
- errors::InvalidArgument("dim must be < input rank (",
- input_shape.dims(), "), but got: ", dim));
- const int64 dim_size = input_shape.dim_size(dim);
- OP_REQUIRES(ctx, dim_size > 0,
+
+ const int input_dims = input_shape.dims();
+ const int axis = dim < 0 ? dim + input_dims : dim;
+ OP_REQUIRES(ctx, axis >= 0 && axis < input_dims,
+ errors::InvalidArgument("Expected dimension in the range [",
+ -input_dims, ", ", input_dims,
+ "), but got ", dim));
+
+ const int64 axis_size = input_shape.dim_size(axis);
+ OP_REQUIRES(ctx, axis_size > 0,
errors::InvalidArgument(
"Reduction axis ", dim,
" is empty in shape: ", input_shape.DebugString()));
- // The output shape is the input shape contracted along dim.
- TensorShape output_shape;
- for (int d = 0; d < input_shape.dims() - 1; ++d) {
- output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1));
+ const DataType dtype = output_type(0);
+ xla::PrimitiveType output_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &output_type));
+
+ // Fall back to XLA ArgMax HLO when CustomCall is not allowed or when input
+ // shape isn't supported.
+ if (!ctx->compiler()->options().allow_cpu_custom_calls ||
+ (input_dims != 1 && input_dims != 2)) {
+ xla::XlaOp output = XlaHelpers::ArgMax(ctx->Input(0), output_type, axis);
+ ctx->SetOutput(0, output);
+ return;
}
- // For now we use a custom-call, only for the 1d and 2d cases.
- OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(),
- errors::InvalidArgument(
- "ArgMax implementation requires a CustomCall on CPU"));
+ xla::XlaOp output;
+ // The output shape is the input shape contracted along axis.
+ TensorShape output_shape;
+ for (int d = 0; d < input_shape.dims() - 1; ++d) {
+ output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
+ }
+
xla::XlaBuilder& b = *ctx->builder();
// XLA passes <out> to the function, so it is not included here.
@@ -84,7 +99,7 @@
args.push_back(xla::ConstantLiteral(
&b, xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
args.push_back(
- xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
+ xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(axis)));
}
// The argmax function expects row-major layout.
@@ -101,24 +116,15 @@
}
// Tell XLA to call the custom code, defined in
- // index_ops_kernel_argmax_float_1d.cc.
- xla::XlaOp output;
- switch (input_shape.dims()) {
- case 1:
- output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args,
- xla_shape, arg_shapes);
- break;
- case 2:
- output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args,
- xla_shape, arg_shapes);
- break;
- default:
- OP_REQUIRES(ctx, false,
- errors::Unimplemented(
- "Argmax is only implemented for 1d and 2d tensors"
- ", but got shape: ",
- input_shape.DebugString()));
+ // index_ops_kernel_argmax_float_{1, 2}d.cc.
+ if (input_dims == 1) {
+ output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args,
+ xla_shape, arg_shapes);
+ } else {
+ output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args,
+ xla_shape, arg_shapes);
}
+ output = xla::ConvertElementType(output, output_type);
ctx->SetOutput(0, output);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
index f028e36..93f0297 100644
--- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
@@ -37,12 +37,11 @@
// output = sum(t ** 2) / 2
const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype);
- auto t =
- XlaHelpers::ConvertElementType(b, ctx->Input(0), accumulation_type);
+ auto t = XlaHelpers::ConvertElementType(ctx->Input(0), accumulation_type);
auto square = xla::Mul(t, t);
auto reduce = xla::Reduce(square, XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), dims);
- auto deconverted = XlaHelpers::ConvertElementType(b, reduce, dtype);
+ auto deconverted = XlaHelpers::ConvertElementType(reduce, dtype);
auto two = XlaHelpers::IntegerLiteral(b, dtype, 2);
ctx->SetOutput(0, xla::Div(deconverted, two));
}
diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
index 87ee2d3..987901d 100644
--- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
@@ -49,16 +49,14 @@
// We use a window of depth_radius_ * 2 + 1, to account for the current
// element and a depth_radius_ on either side.
auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0));
- auto converted =
- XlaHelpers::ConvertElementType(builder, input, accumulation_type);
+ auto converted = XlaHelpers::ConvertElementType(input, accumulation_type);
auto squared = xla::Mul(converted, converted);
auto reduce = xla::ReduceWindow(
squared, XlaHelpers::Zero(builder, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type),
/* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
/* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);
- auto sqr_sum =
- XlaHelpers::ConvertElementType(builder, reduce, input_type(0));
+ auto sqr_sum = XlaHelpers::ConvertElementType(reduce, input_type(0));
auto scale = xla::Pow(
xla::Add(xla::ConstantR0<float>(builder, bias_),
@@ -138,15 +136,14 @@
auto accumulation_type = XlaHelpers::SumAccumulationType(input_type(0));
auto converted =
- XlaHelpers::ConvertElementType(builder, in_image, accumulation_type);
+ XlaHelpers::ConvertElementType(in_image, accumulation_type);
auto squared = xla::Mul(converted, converted);
auto reduce = xla::ReduceWindow(
squared, XlaHelpers::Zero(builder, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type),
/* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
/* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);
- auto sqr_sum =
- XlaHelpers::ConvertElementType(builder, reduce, input_type(0));
+ auto sqr_sum = XlaHelpers::ConvertElementType(reduce, input_type(0));
auto norm =
xla::Add(xla::ConstantR0<float>(builder, bias_),
@@ -157,15 +154,13 @@
xla::Div(out_image, norm)),
in_grads);
- auto converted_dy =
- XlaHelpers::ConvertElementType(builder, dy, accumulation_type);
+ auto converted_dy = XlaHelpers::ConvertElementType(dy, accumulation_type);
auto dy_reduce = xla::ReduceWindow(
converted_dy, XlaHelpers::Zero(builder, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type),
/* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
/* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);
- auto dy_reduced =
- XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0));
+ auto dy_reduced = XlaHelpers::ConvertElementType(dy_reduce, input_type(0));
xla::XlaOp gradients = xla::Add(
xla::Mul(in_image, dy_reduced),
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
index 8dfd7de..2dd0a71 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc
@@ -16,8 +16,8 @@
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
@@ -61,11 +61,11 @@
// Compute 'offset', which is how many diagonals we are above/below the
// diagonal.
- xla::XlaOp iota_m = xla::Iota(builder, index_xla_type, m);
- xla::XlaOp iota_n = xla::Iota(builder, index_xla_type, n);
+ xla::Shape iota_shape = xla::ShapeUtil::MakeShape(index_xla_type, {m, n});
+ xla::XlaOp iota_m = xla::Iota(builder, iota_shape, /*iota_dimension=*/0);
+ xla::XlaOp iota_n = xla::Iota(builder, iota_shape, /*iota_dimension=*/1);
- auto offset = xla::Sub(xla::Broadcast(iota_n, {m}), iota_m,
- /*broadcast_dimensions=*/{0});
+ auto offset = xla::Sub(iota_n, iota_m);
// If num_lower or num_upper are negative, include all lower/upper
// diagonals.
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
index c0ca881..4f980b6 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc
@@ -16,7 +16,6 @@
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
index 6f4ed49..7fe1024 100644
--- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
@@ -19,6 +19,7 @@
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/platform/macros.h"
@@ -26,12 +27,26 @@
namespace tensorflow {
namespace {
+enum QuantizerRoundMode {
+ // Round half up: if the fraction of y is exactly 0.5, then
+ // round(y) = y + 0.5
+ // E.g., -5.5 gets rounded to -5, -5.4 goes to -5,
+ // 5.4 goes to 5, and 5.5 goes to 6.
+ ROUND_HALF_UP,
+ // Round half to even: if the fraction of y is exactly 0.5, then round(y) is
+ // the nearest even integer to y.
+ // E.g., 23.5 gets rounded to 24, 24.5 gets rounded to 24, while -23.5 becomes
+ // -24, and -24.5 gets rounded to 24.
+ ROUND_HALF_TO_EVEN,
+};
+
class QuantizeAndDequantizeOp : public XlaOpKernel {
public:
explicit QuantizeAndDequantizeOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
+ round_mode_ = ROUND_HALF_TO_EVEN;
}
void Compile(XlaOpKernelContext* ctx) override {
@@ -117,8 +132,17 @@
// in that case they were measured from the tensor.
input = Clamp(min_range, input, max_range);
}
- xla::XlaOp result =
- Floor((input - min_range) * scale + half) * inverse_scale + min_range;
+ xla::XlaOp result;
+ switch (round_mode_) {
+ case ROUND_HALF_TO_EVEN: {
+ result = xla::RoundToEven(input * scale) * inverse_scale;
+ break;
+ }
+ case ROUND_HALF_UP: {
+ result = Floor(input * scale + half) * inverse_scale;
+ break;
+ }
+ }
ctx->SetOutput(0, result);
}
@@ -126,6 +150,7 @@
int64 num_bits_ = -1;
bool signed_input_;
bool range_given_;
+ QuantizerRoundMode round_mode_;
};
class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp {
@@ -136,6 +161,20 @@
OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
errors::InvalidArgument("num_bits is out of range: ", num_bits_,
" with signed_input_ ", signed_input_));
+ string round_mode_string;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string));
+ OP_REQUIRES(
+ ctx,
+ (round_mode_string == "HALF_UP" || round_mode_string == "HALF_TO_EVEN"),
+ errors::InvalidArgument("Round mode string must be "
+ "'HALF_UP' or "
+ "'HALF_TO_EVEN', is '" +
+ round_mode_string + "'"));
+ if (round_mode_string == "HALF_UP") {
+ round_mode_ = ROUND_HALF_UP;
+ } else if (round_mode_string == "HALF_TO_EVEN") {
+ round_mode_ = ROUND_HALF_TO_EVEN;
+ }
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
index 415ce9b..8822e29 100644
--- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -26,7 +26,6 @@
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
-#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
index 132160d..65e158d 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -125,10 +125,9 @@
auto size = xla::GetDimensionSize(input, dimensions_to_reduce[i]);
divisor = xla::Mul(divisor, size);
}
- xla::PrimitiveType type;
- TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type));
- divisor = xla::ConvertElementType(divisor, type);
- return reduce_output / divisor;
+ divisor = xla::ConvertElementType(divisor, xla_reduction_type_);
+ return XlaHelpers::ConvertElementType(reduce_output / divisor,
+ input_type(0));
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
index 8f1667d..af716ea 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
@@ -50,8 +50,8 @@
// Applies a transformation to the output of the reduction. The desired
// computation should be added to 'builder'. Argument 'input' is the original
// input of the reduction; 'reduce_output' is the output of the reduction.
- // Returns the transformed reduction output, Defaults to returning
- // 'reduce_output' unchanged.
+ // Returns the transformed reduction output. Defaults to returning
+ // 'reduce_output' converted to the input type.
virtual xla::XlaOp BuildFinalizer(
xla::XlaBuilder* builder, const xla::XlaOp& input,
const xla::XlaOp& reduce_output,
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
index e96cabb..2ca2a85 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -35,13 +35,13 @@
ctx, DataTypeToPrimitiveType(reduction_type_, &xla_reduction_type_));
}
-// Unless BuildFinalizer is overridden the reduction has no
-// finalizer.
+// The default finalizer converts the results back into the input type. This can
+// be overridden.
xla::XlaOp XlaReductionOp::BuildFinalizer(
xla::XlaBuilder* /*builder*/, const xla::XlaOp& /*input*/,
const xla::XlaOp& reduce_output,
const std::vector<int64>& /*dimensions_to_reduce*/) {
- return reduce_output;
+ return XlaHelpers::ConvertElementType(reduce_output, input_type(0));
}
void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
@@ -117,8 +117,7 @@
xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie();
auto reduce = xla::Reduce(data, initial, reduction_computation, xla_axes);
- auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0));
- auto finalized = BuildFinalizer(b, data, deconverted, xla_axes);
+ auto finalized = BuildFinalizer(b, data, reduce, xla_axes);
auto result = keep_dims_ ? xla::Reshape(finalized, final_shape) : finalized;
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc
index 8477046..8a8f33c 100644
--- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc
@@ -44,9 +44,6 @@
using xla::XlaOp;
-// TODO(b/112295522): note that sampling from image boundary is not currently
-// being handled properly.
-
// Calculates the bilinear weight tensor, given basis ratio (px, py) of the
// sampling position:
// W = [(1-px)*(1-py), px*(1-py), (1-px)*py, px*py]
@@ -421,12 +418,13 @@
OP_REQUIRES(ctx, warp_shape.dim_size(last_warp_dim) == 2,
errors::InvalidArgument(
"the last dimension of warp must be exactly size 2."));
+ xla::PrimitiveType warp_type = ctx->input_xla_type(1);
XlaOp data = ctx->Input("data");
XlaOp warp = ctx->Input("warp");
// Find the coordinates of the top left corner for the 2x2 region to be
- // sampled from. The dimensions are (batch, dim_0, ... dim_n, 2) where the
+ // sampled from. The dimensions are [batch, dim_0, ... dim_n, 2] where the
// last dimension of size 2 in turn is [x, y].
XlaOp top_left = xla::ConvertElementType(warp, xla::U32);
@@ -457,10 +455,56 @@
dot_dims.add_lhs_contracting_dimensions(warp_shape.dims() - 1);
dot_dims.add_rhs_contracting_dimensions(warp_shape.dims() - 1);
+ // The dimension is [batch, dim_0, ...dim_n, data_channels].
auto blended_pixels = xla::DotGeneral(weights, neighbors_data, dot_dims,
/*precision_config=*/nullptr);
- ctx->SetOutput(0, blended_pixels);
+ // Handle out of boundary cases by constructing a predicate mask array based
+ // on the in-bound condition, and output 0 for the blended pixel value if
+ // out-bound. The dimension is the same as top_left: [batch, dim_0,
+ // ...dim_n, 2] where the last dimension of size 2 is the [x, y] coordinate.
+
+ auto is_ge_zero = xla::Ge(warp, xla::ZerosLike(warp));
+
+ auto is_lt_image_size = xla::Lt(
+ warp,
+ xla::ConvertElementType(
+ xla::ConstantR1<float>(
+ ctx->builder(),
+ {/*width=*/static_cast<float>(data_shape.dim_size(2) - 1),
+ /*height=*/static_cast<float>(data_shape.dim_size(1) - 1)}),
+ warp_type),
+ /*broadcast_dimensions=*/{warp_shape.dims() - 1});
+
+ auto is_in_bound_x_y = xla::And(is_ge_zero, is_lt_image_size);
+ // Reduce along last dimension. The resulting dimension is:
+ // [batch, dim_0, ...dim_n].
+ auto is_in_bound = xla::Reduce(
+ is_in_bound_x_y, xla::ConstantR0<bool>(ctx->builder(), true),
+ xla::CreateScalarAndComputation(xla::PrimitiveType::PRED,
+ ctx->builder()),
+ {last_warp_dim});
+
+ // Broadcast 'is_in_bound' to the same dimension as 'blended_pixels', which
+ // is the dimension of the result:
+ // [batch, dim_0, ...dim_n, data_channels].
+ auto warp_dims = warp_shape.dim_sizes();
+ std::vector<int64> result_dims(warp_dims.begin(), warp_dims.end() - 1);
+ result_dims.push_back(data_channels);
+ xla::Shape broadcasted_shape =
+ xla::ShapeUtil::MakeShape(xla::PrimitiveType::PRED, result_dims);
+
+ std::vector<int64> broadcasted_dims(warp_dims.size() - 1);
+ std::iota(broadcasted_dims.begin(), broadcasted_dims.end(), 0);
+ auto broadcasted_is_in_bound =
+ xla::BroadcastInDim(is_in_bound, broadcasted_shape, broadcasted_dims);
+
+ // Set out of bound samples to zero.
+ auto zeros =
+ xla::Broadcast(xla::Zero(ctx->builder(), data_type), result_dims);
+ auto result = xla::Select(broadcasted_is_in_bound, blended_pixels, zeros);
+
+ ctx->SetOutput(0, result);
}
};
@@ -473,6 +517,8 @@
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype));
}
+ // TODO(b/112295522): note that sampling from image boundary is not currently
+ // being handled properly.
void Compile(XlaOpKernelContext* ctx) override {
TensorShape data_shape_tf = ctx->InputShape("data");
OP_REQUIRES(ctx, data_shape_tf.dims() == 4,
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index 6970dd0..e4046c7 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -47,8 +47,7 @@
// compilation.
OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input));
} else {
- XlaContext& xla_context = XlaContext::Get(ctx);
- xla_context.SetRetval(index_, ctx->InputExpression(0));
+ ctx->xla_context()->SetRetval(index_, ctx->InputExpression(0));
}
}
diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
index 7ff3e91..d7b38e8 100644
--- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc
@@ -18,7 +18,6 @@
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
index b5fd785..4b9e1a5 100644
--- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -39,8 +39,8 @@
// TODO(phawkins): implement double-sized windowed reductions in XLA and remove
// the type constraint.
-constexpr std::array<DataType, 3> kScanOpTypes = {
- {DT_HALF, DT_BFLOAT16, DT_FLOAT}};
+constexpr std::array<DataType, 4> kScanOpTypes = {
+ {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_INT32}};
class ScanOp : public XlaOpKernel {
public:
@@ -103,11 +103,10 @@
reducer = ctx->GetOrCreateMul(dtype);
}
auto output = xla::ReduceWindowWithGeneralPadding(
- XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init,
- *reducer, window_dims, window_strides,
+ XlaHelpers::ConvertElementType(ctx->Input(0), dtype), init, *reducer,
+ window_dims, window_strides,
/*base_dilations=*/{}, /*window_dilations=*/{}, padding);
- output =
- XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0));
+ output = XlaHelpers::ConvertElementType(output, ctx->input_type(0));
// In exclusive mode, we have computed an extra element containing the sum
// of all the input elements. Slice off this extra "last" element.
diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
index a7f5a8f..84470b2 100644
--- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc
@@ -42,7 +42,7 @@
}
void SendOp::Compile(XlaOpKernelContext* ctx) {
- XlaCompiler* compiler = XlaContext::Get(ctx).compiler();
+ XlaCompiler* compiler = ctx->compiler();
xla::ChannelHandle channel;
OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel));
xla::Send(ctx->Input(0), channel);
@@ -73,7 +73,7 @@
}
void RecvOp::Compile(XlaOpKernelContext* ctx) {
- XlaCompiler* compiler = XlaContext::Get(ctx).compiler();
+ XlaCompiler* compiler = ctx->compiler();
xla::ChannelHandle channel;
OP_REQUIRES_OK(ctx, compiler->GetChannelHandle(tensor_name_, &channel));
ctx->SetOutput(0, xla::Recv(ctx->builder(), shape_, channel));
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
index 60b011b..b1fa291 100644
--- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -18,7 +18,7 @@
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/op_kernel.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
index d6bd927..20da803 100644
--- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
@@ -71,7 +71,7 @@
auto reduce =
xla::Reduce(converted, xla::Zero(b, xla_accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
- auto sum = XlaHelpers::ConvertElementType(b, reduce, type);
+ auto sum = XlaHelpers::ConvertElementType(reduce, type);
auto softmax =
log_
// softmax = shifted_logits - log(sum(exp(shifted_logits)))
@@ -111,11 +111,11 @@
// sum_{class} (exp(logits - max_logits))
const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
auto converted =
- XlaHelpers::ConvertElementType(b, exp_shifted_logits, accumulation_type);
+ XlaHelpers::ConvertElementType(exp_shifted_logits, accumulation_type);
auto reduce =
xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
- auto sum_exp = XlaHelpers::ConvertElementType(b, reduce, type);
+ auto sum_exp = XlaHelpers::ConvertElementType(reduce, type);
// log(sum(exp(logits - max_logits)))
auto log_sum_exp = xla::Log(sum_exp);
@@ -126,11 +126,10 @@
// (The subtraction broadcasts along the batch dimension.)
auto sub = xla::Sub(shifted_logits, log_sum_exp, {kBatchDim});
auto mul = xla::Mul(xla::Neg(labels), sub);
- auto sum =
- xla::Reduce(XlaHelpers::ConvertElementType(b, mul, accumulation_type),
- XlaHelpers::Zero(b, accumulation_type),
- *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
- auto loss = XlaHelpers::ConvertElementType(b, sum, type);
+ auto sum = xla::Reduce(XlaHelpers::ConvertElementType(mul, accumulation_type),
+ XlaHelpers::Zero(b, accumulation_type),
+ *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
+ auto loss = XlaHelpers::ConvertElementType(sum, type);
// backprop: prob - labels, where
// prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
index 7b96b43..8e9e4da 100644
--- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
@@ -69,7 +69,7 @@
}
TensorShape stack_shape;
- stack_shape.AddDim(resource->tensor_array_size());
+ stack_shape.AddDim(resource->max_array_size());
stack_shape.AppendShape(elem_shape);
if (!resource->initialized()) {
@@ -97,10 +97,10 @@
}
void Compile(XlaOpKernelContext* ctx) override {
- int64 size;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &size));
+ int64 max_size;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &max_size));
OP_REQUIRES(
- ctx, size >= 0,
+ ctx, max_size >= 0,
errors::InvalidArgument(
"XLA compilation requires a fixed stack size upper bound. If "
"you are using tf.while_loop, set the maximum_iterations parameter "
@@ -108,14 +108,9 @@
// We defer initializing the Stack resource until we see the first push.
// Otherwise we do not know the shape of the stack elements.
- xla::XlaOp value;
- XlaContext& xc = XlaContext::Get(ctx);
- XlaResource* resource;
- string name = absl::StrCat("Stack: ", stack_name_);
- OP_REQUIRES_OK(
- ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_,
- TensorShape(), value, /*tensor_array_size=*/size,
- /*tensor_array_gradients=*/{}, &resource));
+ XlaResource* resource =
+ ctx->xla_context()->AddResource(XlaResource::CreateStack(
+ /*name=*/absl::StrCat("Stack: ", stack_name_), dtype_, max_size));
ctx->SetResourceOutput(0, resource);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 252967a..939d7e1 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -61,8 +61,8 @@
" but op has dtype ", DataTypeString(dtype), ".");
}
- TF_RET_CHECK(resource->tensor_array_size() >= 0)
- << resource->name() << " size " << resource->tensor_array_size();
+ TF_RET_CHECK(resource->max_array_size() >= 0)
+ << resource->name() << " size " << resource->max_array_size();
if (!resource->initialized()) {
TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
@@ -78,7 +78,7 @@
XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape));
TensorShape ta_shape;
- ta_shape.AddDim(resource->tensor_array_size());
+ ta_shape.AddDim(resource->max_array_size());
ta_shape.AppendShape(elem_shape);
if (ta_shape != shape) {
return errors::InvalidArgument(
@@ -114,7 +114,7 @@
Status GetTensorArrayShape(const XlaResource* resource,
xla::XlaBuilder* builder, TensorShape* shape) {
*shape = resource->shape();
- shape->InsertDim(0, resource->tensor_array_size());
+ shape->InsertDim(0, resource->max_array_size());
return Status::OK();
}
@@ -166,13 +166,10 @@
value = xla::Broadcast(zero, ta_shape.dim_sizes());
}
- XlaContext& xc = XlaContext::Get(ctx);
- XlaResource* var;
- string name = absl::StrCat("TensorArray: ", tensor_array_name_);
- OP_REQUIRES_OK(
- ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
- dtype_, shape, value, /*tensor_array_size=*/size,
- /*tensor_array_gradients=*/{}, &var));
+ XlaResource* var =
+ ctx->xla_context()->AddResource(XlaResource::CreateTensorArray(
+ /*name=*/absl::StrCat("TensorArray: ", tensor_array_name_), dtype_,
+ shape, /*initial_value=*/value, /*max_array_size=*/size));
ctx->SetResourceOutput(0, var);
Tensor flow(DT_FLOAT, TensorShape({}));
@@ -517,14 +514,13 @@
xla::XlaOp ta = resource->value();
TensorShape ta_shape;
- ta_shape.AddDim(resource->tensor_array_size());
+ ta_shape.AddDim(resource->max_array_size());
ta_shape.AppendShape(elem_shape);
- OP_REQUIRES(
- ctx, lengths.size() == resource->tensor_array_size(),
- errors::InvalidArgument(
- "TensorArray's size is not equal to the size of lengths (",
- lengths.size(), " vs. ", resource->tensor_array_size(), ")"));
+ OP_REQUIRES(ctx, lengths.size() == resource->max_array_size(),
+ errors::InvalidArgument(
+ "TensorArray's size is not equal to the size of lengths (",
+ lengths.size(), " vs. ", resource->max_array_size(), ")"));
const xla::XlaOp value = ctx->Input(1);
const xla::XlaOp flow = ctx->Input(3);
@@ -562,8 +558,7 @@
XlaResource* var;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var));
Tensor size_tensor(DT_INT32, {});
- size_tensor.scalar<int32>()() =
- static_cast<int32>(var->tensor_array_size());
+ size_tensor.scalar<int32>()() = static_cast<int32>(var->max_array_size());
ctx->SetConstantOutput(0, size_tensor);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index 7077c2e..960c146 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -320,9 +320,8 @@
xla::XlaOp lr = ctx->Input(4);
xla::XlaOp l1 = ctx->Input(5);
xla::XlaOp l2 = ctx->Input(6);
- xla::XlaBuilder* const b = ctx->builder();
xla::XlaOp global_step =
- XlaHelpers::ConvertElementType(b, ctx->Input(7), dtype_);
+ XlaHelpers::ConvertElementType(ctx->Input(7), dtype_);
accum = accum + grad;
squared_accum = squared_accum + xla::Square(grad);
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 559414e..ce007fc 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -64,7 +64,7 @@
if (!arg.initialized) {
*has_uninitialized_vars = true;
}
- arg.tensor_array_size = resource->tensor_array_size();
+ arg.max_array_size = resource->max_array_size();
for (const auto& gradient : resource->tensor_array_gradients()) {
arg.tensor_array_gradients.insert(gradient.first);
}
diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD
index c9f486e..fef97b9 100644
--- a/tensorflow/compiler/tf2xla/python/BUILD
+++ b/tensorflow/compiler/tf2xla/python/BUILD
@@ -1,11 +1,13 @@
licenses(["notice"]) # Apache 2.0
+package_group(
+ name = "friends",
+ includes = ["//tensorflow:internal"],
+)
+
package(
default_visibility = [
- "//learning/deepmind/public/wavenet/python:__subpackages__",
- "//learning/deepmind/research/alphastar:__subpackages__",
- "//learning/tfx:__subpackages__",
- "//tensorflow:internal",
+ ":friends",
],
)
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
index 6620690..a1d359e 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -26,7 +26,7 @@
// Forward-declare, rather than include, to reduce code size for users that
// never use this functionality.
namespace xla {
-class ProgramShape;
+class ProgramShapeProto;
class HloProfilePrinterData;
}
@@ -84,7 +84,7 @@
void set_result_names(const char** result_names) {
result_names_ = result_names;
}
- void set_program_shape(const xla::ProgramShape* program_shape) {
+ void set_program_shape(const xla::ProgramShapeProto* program_shape) {
program_shape_ = program_shape;
}
const xla::HloProfilePrinterData* hlo_profile_printer_data() const {
@@ -122,7 +122,7 @@
const char** result_names_ = nullptr;
// [Optional] Arg and result shapes.
- const xla::ProgramShape* program_shape_ = nullptr;
+ const xla::ProgramShapeProto* program_shape_ = nullptr;
// [Optional] Profile printer data. Null if profiling is disabled.
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
@@ -264,7 +264,7 @@
// Returns the shape of the args and results. May return nullptr if the
// program shape isn't available.
- const xla::ProgramShape* ProgramShape() const { return program_shape_; }
+ const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; }
bool hlo_profiling_enabled() const {
return hlo_profile_printer_data_ != nullptr;
@@ -305,7 +305,7 @@
// Optional metadata.
const char** arg_names_ = nullptr;
const char** result_names_ = nullptr;
- const xla::ProgramShape* program_shape_ = nullptr;
+ const xla::ProgramShapeProto* program_shape_ = nullptr;
const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
};
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 8036bc6..ee461a3 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -326,10 +326,10 @@
bool XlaCompiler::Argument::operator==(
const XlaCompiler::Argument& other) const {
- if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size,
+ if (std::tie(kind, resource_kind, type, name, initialized, max_array_size,
tensor_array_gradients) !=
std::tie(other.kind, other.resource_kind, other.type, other.name,
- other.initialized, other.tensor_array_size,
+ other.initialized, other.max_array_size,
other.tensor_array_gradients)) {
return false;
}
@@ -359,8 +359,8 @@
string output = absl::StrCat("kind=resource", common, " resource_kind=",
XlaResource::KindToString(resource_kind),
" initialized=", initialized);
- if (tensor_array_size >= 0) {
- absl::StrAppend(&output, " tensor_array_size=", tensor_array_size);
+ if (max_array_size >= 0) {
+ absl::StrAppend(&output, " max_array_size=", max_array_size);
}
if (!tensor_array_gradients.empty()) {
absl::StrAppend(&output, " tensor_array_gradients=",
@@ -380,7 +380,7 @@
initialization_status_(Status::OK()),
next_step_id_(1),
device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)),
- device_mgr_({device_}) {
+ device_mgr_(absl::WrapUnique(device_)) {
CHECK(!options_.device_type.type_string().empty());
if (options_.populate_resource_manager) {
initialization_status_ =
@@ -567,12 +567,12 @@
return Status::OK();
}
case XlaResource::kTensorArray: {
- if (arg.tensor_array_size < 0) {
+ if (arg.max_array_size < 0) {
return errors::InvalidArgument(
- "Negative tensor_array_size in XLAShapeForArgument");
+ "Negative max_array_size in XLAShapeForArgument");
}
TensorShape shape;
- shape.AddDim(arg.tensor_array_size);
+ shape.AddDim(arg.max_array_size);
shape.AppendShape(arg.shape);
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
@@ -584,12 +584,12 @@
return Status::OK();
}
case XlaResource::kStack: {
- if (arg.tensor_array_size < 0) {
+ if (arg.max_array_size < 0) {
return errors::InvalidArgument(
- "Negative tensor_array_size in XLAShapeForArgument");
+ "Negative max_array_size in XLAShapeForArgument");
}
TensorShape shape;
- shape.AddDim(arg.tensor_array_size);
+ shape.AddDim(arg.max_array_size);
shape.AppendShape(arg.shape);
xla::Shape buffer_shape;
TF_RETURN_IF_ERROR(
@@ -635,21 +635,23 @@
const XlaCompiler::Argument& arg = args[i];
XlaExpression& arg_expression = (*arg_expressions)[i];
switch (arg.kind) {
- case XlaCompiler::Argument::kResource:
+ case XlaCompiler::Argument::kResource: {
TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
// TODO(phawkins): this code assumes that resource arguments do not
// alias.
- XlaResource* resource;
- TF_RETURN_IF_ERROR(context->CreateResource(
- arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(),
- /*tensor_array_size=*/arg.tensor_array_size,
- /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
+ XlaResource* resource =
+ context->AddResource(absl::make_unique<XlaResource>(
+ arg.resource_kind, i, arg.name, arg.type, arg.shape,
+ xla::XlaOp(),
+ /*max_array_size=*/arg.max_array_size,
+ /*tensor_array_gradients=*/arg.tensor_array_gradients,
+ /*tensor_array_multiple_writes_aggregate=*/true));
arg_expression = XlaExpression::Resource(resource);
if (arg.initialized) {
input_mapping->push_back(i);
}
-
break;
+ }
case XlaCompiler::Argument::kParameter:
case XlaCompiler::Argument::kToken: {
input_mapping->push_back(i);
@@ -923,9 +925,7 @@
options_.device_type, name));
xla::XlaBuilder builder(name);
- XlaContext* context =
- new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
- &options_.shape_representation_fn);
+ XlaContext* context = new XlaContext(this, &builder);
core::ScopedUnref context_unref(context);
std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 6342612..0d801b7 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -150,7 +150,7 @@
// For a TensorArray or Stack resource, what is the array's declared size?
// (Used for lazy initialization.)
- int64 tensor_array_size = -1;
+ int64 max_array_size = -1;
// TensorArray resource parameters are passed as (array, gradient array 0,
// ..., gradient array k), where the gradient arrays are in the same order
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index eba5d77..fe2a5f5 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -650,7 +650,7 @@
args[0].initialized = true;
args[0].type = DT_INT32;
args[0].shape = TensorShape({});
- args[0].tensor_array_size = 2;
+ args[0].max_array_size = 2;
args[0].tensor_array_gradients = {"grad2"};
// Compiles the graph.
@@ -709,7 +709,7 @@
args[0].initialized = true;
args[0].type = DT_INT32;
args[0].shape = TensorShape({});
- args[0].tensor_array_size = 2;
+ args[0].max_array_size = 2;
args[0].tensor_array_gradients = {"grad1"};
// Compiles the graph.
@@ -741,7 +741,7 @@
args[0].initialized = true;
args[0].type = DT_INT32;
args[0].shape = TensorShape({});
- args[0].tensor_array_size = 2;
+ args[0].max_array_size = 2;
args[0].tensor_array_gradients = {"grad1"};
// Compiles the graph.
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 43095fb..a69af70 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -54,25 +54,14 @@
return *context;
}
-/* static */ XlaContext& XlaContext::Get(const XlaOpKernelContext* ctx) {
- return Get(ctx->op_kernel_context());
-}
-
void XlaContext::set_args(std::vector<XlaExpression> args) {
args_ = std::move(args);
}
-XlaContext::XlaContext(
- XlaCompiler* compiler, xla::XlaBuilder* builder,
- bool allow_cpu_custom_calls,
- const std::function<xla::StatusOr<xla::Shape>(
- const TensorShape&, DataType)>* shape_representation_fn)
- : compiler_(compiler),
- builder_(builder),
- allow_cpu_custom_calls_(allow_cpu_custom_calls),
- shape_representation_fn_(shape_representation_fn) {}
+XlaContext::XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder)
+ : compiler_(compiler), builder_(builder) {}
-string XlaContext::DebugString() { return "TLA JIT context"; }
+string XlaContext::DebugString() { return "XLA JIT context"; }
void XlaContext::SetRetval(int index, const XlaExpression& expression) {
if (retvals_.size() <= index) {
@@ -81,21 +70,9 @@
retvals_[index] = expression;
}
-Status XlaContext::CreateResource(
- XlaResource::Kind kind, int arg_num, string name, DataType type,
- TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size,
- const std::set<string>& tensor_array_gradients, XlaResource** resource) {
- resources_.emplace_back(
- new XlaResource(kind, arg_num, std::move(name), type, std::move(shape),
- handle, tensor_array_size, tensor_array_gradients,
- /*tensor_array_multiple_writes_aggregate=*/false));
- *resource = resources_.back().get();
- return Status::OK();
-}
-
-xla::StatusOr<xla::Shape> XlaContext::RepresentationShape(
- const TensorShape& shape, DataType type) const {
- return (*shape_representation_fn_)(shape, type);
+XlaResource* XlaContext::AddResource(std::unique_ptr<XlaResource> resource) {
+ resources_.push_back(std::move(resource));
+ return resources_.back().get();
}
const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index dbfd344..0767d1f 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -41,14 +41,10 @@
public:
// Retrieves the XlaContext of the current compilation.
static XlaContext& Get(const OpKernelContext* ctx);
- static XlaContext& Get(const XlaOpKernelContext* ctx);
// Creates a new XlaContext. See the documentation on the class data fields
// for descriptions of the arguments.
- XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
- bool allow_cpu_custom_calls,
- const std::function<xla::StatusOr<xla::Shape>(
- const TensorShape&, DataType)>* shape_representation_fn);
+ XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder);
// Virtual method defined by ResourceBase.
string DebugString() override;
@@ -58,8 +54,6 @@
// Returns the XlaBuilder that Ops use for compiling new expressions.
xla::XlaBuilder* builder() { return builder_; }
- bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
-
const std::vector<XlaExpression>& args() const { return args_; }
void set_args(std::vector<XlaExpression> args);
@@ -70,25 +64,13 @@
// grows the return values vector to size index+1 if it is smaller.
void SetRetval(int index, const XlaExpression& expression);
- // Creates a resource with resource `kind` and initial value `handle`. `name`
- // is a descriptive name for use in error messages. See the `XlaResource`
- // constructor for a description of the remaining arguments.
- // Fails if the resource already exists.
- Status CreateResource(XlaResource::Kind kind, int arg_num, string name,
- DataType type, TensorShape shape,
- const xla::XlaOp& handle, int64 tensor_array_size,
- const std::set<string>& tensor_array_gradients,
- XlaResource** resource);
+ // Adds 'resource' to the set of resources owned by the context.
+ XlaResource* AddResource(std::unique_ptr<XlaResource> resource);
const std::vector<std::unique_ptr<XlaResource>>& resources() {
return resources_;
}
- // Returns the XLA shape to be used to represent a variable of TF `shape`
- // and `type`, or of an argument or return value of a top-level computation.
- xla::StatusOr<xla::Shape> RepresentationShape(const TensorShape& shape,
- DataType type) const;
-
// Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
// separate specialization of the computation for each DataType.
@@ -118,9 +100,6 @@
// The XlaBuilder used to construct the subgraph's compiled representation.
xla::XlaBuilder* builder_;
- // Allow ops to emit CustomCall operations for CPU.
- const bool allow_cpu_custom_calls_;
-
// Arguments to the Tensorflow graph, indexed by _Arg index.
// Includes both compile-time constant arguments and runtime parameters.
std::vector<XlaExpression> args_;
@@ -131,11 +110,6 @@
// Holds ownership of resources. The resources are not ordered.
std::vector<std::unique_ptr<XlaResource>> resources_;
- // Describes the on-host shapes of parameters and return values. Also see:
- // XlaDevice::Options::shape_representation_fn.
- const std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType)>*
- shape_representation_fn_;
-
// Cache of prebuilt computations indexed by their type.
using ComputationMap = std::map<DataType, xla::XlaComputation>;
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 9a34cd8..c2c0751 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -26,7 +26,6 @@
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/types.h"
@@ -216,8 +215,7 @@
return dtype;
}
-xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder,
- const xla::XlaOp& operand,
+xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand,
const DataType new_element_type) {
xla::PrimitiveType convert_to;
TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to));
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
index 3957814..4858dfe 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.h
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -80,8 +80,7 @@
// A helper for creating a ConvertElementType xla op given a DataType rather
// than the xla::PrimitiveType.
- static xla::XlaOp ConvertElementType(xla::XlaBuilder* const builder,
- const xla::XlaOp& operand,
+ static xla::XlaOp ConvertElementType(const xla::XlaOp& operand,
const DataType new_element_type);
};
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
index 86a78ee..fabbcd0 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
@@ -133,7 +133,8 @@
jit->executable_ = std::move(executable);
jit->buffer_infos_ = std::move(buffer_infos);
jit->arg_index_table_ = std::move(arg_index_table);
- jit->program_shape_ = std::move(program_shape);
+ jit->program_shape_ =
+ absl::make_unique<xla::ProgramShapeProto>(program_shape->ToProto());
jit->static_data_.set_raw_function(raw_function);
jit->static_data_.set_buffer_infos(jit->buffer_infos_.data());
jit->static_data_.set_num_buffers(jit->buffer_infos_.size());
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
index d3c8f22..a539205 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
@@ -80,8 +80,10 @@
std::vector<const char*> arg_names_;
std::vector<const char*> result_names_;
- // The backing data for the program shape.
- std::unique_ptr<const xla::ProgramShape> program_shape_;
+ // The backing data for the program shape. The proto form of program shape is
+ // used because the program shape is serialized and embedded in the object
+ // file.
+ std::unique_ptr<const xla::ProgramShapeProto> program_shape_;
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
index 6d49298..4496255 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
@@ -116,7 +116,7 @@
// Check program shape.
using xla::ShapeUtil;
const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
- const xla::ProgramShape* program_shape = function.ProgramShape();
+ const xla::ProgramShapeProto* program_shape = function.ProgramShape();
ASSERT_TRUE(program_shape != nullptr);
ASSERT_EQ(program_shape->parameters_size(), 2);
EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32));
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 8dd8def..58808c7 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -36,8 +36,16 @@
return context_->ValidateInputsAreSameShape(op);
}
+XlaContext* XlaOpKernelContext::xla_context() const {
+ return &XlaContext::Get(context_);
+}
+
xla::XlaBuilder* XlaOpKernelContext::builder() const {
- return XlaContext::Get(this).builder();
+ return xla_context()->builder();
+}
+
+XlaCompiler* XlaOpKernelContext::compiler() const {
+ return xla_context()->compiler();
}
// Retrieves an XlaExpression that was allocated by a previous Op.
@@ -338,8 +346,8 @@
namespace {
Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
- const OpKernelContext* ctx, TensorShape* shape,
- xla::XlaOp* value) {
+ const XlaOpKernelContext* ctx,
+ TensorShape* shape, xla::XlaOp* value) {
const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
@@ -357,10 +365,9 @@
*shape = variable->shape();
}
- XlaContext& xla_context = XlaContext::Get(ctx);
- TF_ASSIGN_OR_RETURN(
- xla::Shape representation_shape,
- xla_context.RepresentationShape(variable->shape(), variable->type()));
+ TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
+ ctx->compiler()->options().shape_representation_fn(
+ variable->shape(), variable->type()));
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
@@ -377,15 +384,15 @@
Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
TensorShape* shape,
xla::XlaOp* value) {
- return ReadVariableInputTensor(context_->input(index), type, context_, shape,
+ return ReadVariableInputTensor(context_->input(index), type, this, shape,
value);
}
Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
DataType type, TensorShape* shape,
xla::XlaOp* value) {
- return ReadVariableInputTensor(GetInputTensorByName(name), type, context_,
- shape, value);
+ return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape,
+ value);
}
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
@@ -464,7 +471,7 @@
namespace {
Status AssignVariableTensor(const Tensor& tensor, DataType type,
- const OpKernelContext* ctx, xla::XlaOp handle,
+ const XlaOpKernelContext* ctx, xla::XlaOp handle,
xla::XlaBuilder* builder) {
const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
@@ -481,9 +488,9 @@
TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
- XlaContext& xla_context = XlaContext::Get(ctx);
- TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
- xla_context.RepresentationShape(shape, type));
+ TF_ASSIGN_OR_RETURN(
+ xla::Shape representation_shape,
+ ctx->compiler()->options().shape_representation_fn(shape, type));
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
@@ -498,19 +505,15 @@
Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
xla::XlaOp handle) {
TF_RET_CHECK(handle.valid());
- return AssignVariableTensor(context_->input(input_index), type, context_,
- handle, builder());
+ return AssignVariableTensor(context_->input(input_index), type, this, handle,
+ builder());
}
Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
xla::XlaOp handle) {
TF_RET_CHECK(handle.valid());
- return AssignVariableTensor(GetInputTensorByName(name), type, context_,
- handle, builder());
-}
-
-XlaCompiler* XlaOpKernelContext::compiler() const {
- return XlaContext::Get(context_).compiler();
+ return AssignVariableTensor(GetInputTensorByName(name), type, this, handle,
+ builder());
}
void XlaOpKernelContext::CtxFailure(const Status& s) {
@@ -530,22 +533,22 @@
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(
const DataType type) {
- return XlaContext::Get(context_).GetOrCreateMax(type);
+ return xla_context()->GetOrCreateMax(type);
}
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin(
const DataType type) {
- return XlaContext::Get(context_).GetOrCreateMin(type);
+ return xla_context()->GetOrCreateMin(type);
}
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd(
const DataType type) {
- return XlaContext::Get(context_).GetOrCreateAdd(type);
+ return xla_context()->GetOrCreateAdd(type);
}
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
const DataType type) {
- return XlaContext::Get(context_).GetOrCreateMul(type);
+ return xla_context()->GetOrCreateMul(type);
}
const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index c06efa2..1858844 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -60,6 +60,8 @@
public:
explicit XlaOpKernelContext(OpKernelContext* context);
+ XlaContext* xla_context() const;
+
// Returns the XLA XlaBuilder containing the output of compilation.
xla::XlaBuilder* builder() const;
diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc
index a322eb9..48a3c01 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.cc
+++ b/tensorflow/compiler/tf2xla/xla_resource.cc
@@ -18,6 +18,7 @@
#include <functional>
#include <memory>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
@@ -39,9 +40,29 @@
}
}
+/*static*/ std::unique_ptr<XlaResource> XlaResource::CreateStack(
+ string name, DataType type, int64 max_size) {
+ return absl::make_unique<XlaResource>(
+ XlaResource::kStack, /*arg_num=*/-1, std::move(name), type, TensorShape(),
+ /*initial_value=*/xla::XlaOp(),
+ /*max_array_size=*/max_size,
+ /*tensor_array_gradients=*/std::set<string>{},
+ /*tensor_array_multiple_writes_aggregate=*/false);
+}
+
+/*static*/ std::unique_ptr<XlaResource> XlaResource::CreateTensorArray(
+ string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
+ int64 max_array_size) {
+ return absl::make_unique<XlaResource>(
+ XlaResource::kTensorArray, /*arg_num=*/-1, std::move(name), type, shape,
+ initial_value, max_array_size,
+ /*tensor_array_gradients=*/std::set<string>{},
+ /*tensor_array_multiple_writes_aggregate=*/false);
+}
+
XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
TensorShape shape, const xla::XlaOp& initial_value,
- int64 tensor_array_size,
+ int64 max_array_size,
const std::set<string>& tensor_array_gradients,
bool tensor_array_multiple_writes_aggregate)
: kind_(kind),
@@ -51,7 +72,7 @@
shape_(std::move(shape)),
value_(initial_value),
initial_value_(initial_value),
- tensor_array_size_(tensor_array_size),
+ max_array_size_(max_array_size),
tensor_array_multiple_writes_aggregate_(
tensor_array_multiple_writes_aggregate) {
CHECK(kind_ != kInvalid);
@@ -60,7 +81,7 @@
tensor_array_gradients_[gradient].reset(new XlaResource(
/*kind=*/kTensorArray, /*arg_num=*/-1,
/*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_,
- xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{},
+ xla::XlaOp(), max_array_size_, /*tensor_array_gradients=*/{},
/*tensor_array_multiple_writes_aggregate=*/true));
}
}
@@ -113,7 +134,7 @@
}
case kTensorArray: {
TensorShape ta_shape;
- ta_shape.AddDim(tensor_array_size_);
+ ta_shape.AddDim(max_array_size_);
ta_shape.AppendShape(shape_);
value_ = xla::Broadcast(XlaHelpers::Zero(builder, type_),
ta_shape.dim_sizes());
@@ -121,7 +142,7 @@
}
case kStack: {
TensorShape ta_shape;
- ta_shape.AddDim(tensor_array_size_);
+ ta_shape.AddDim(max_array_size_);
ta_shape.AppendShape(shape_);
value_ =
xla::Tuple(builder, {xla::Broadcast(XlaHelpers::Zero(builder, type_),
@@ -146,14 +167,14 @@
std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
if (!gradient) {
TensorShape ta_shape;
- ta_shape.AddDim(tensor_array_size_);
+ ta_shape.AddDim(max_array_size_);
ta_shape.AppendShape(shape_);
xla::XlaOp gradient_value =
xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
gradient.reset(
new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
/*name=*/absl::StrCat("TensorArrayGrad: ", name_),
- type_, shape_, gradient_value, tensor_array_size_,
+ type_, shape_, gradient_value, max_array_size_,
/*tensor_array_gradients=*/{},
/*tensor_array_multiple_writes_aggregate=*/true));
}
diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h
index 857b9a9..736588b 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.h
+++ b/tensorflow/compiler/tf2xla/xla_resource.h
@@ -38,9 +38,18 @@
};
static absl::string_view KindToString(Kind kind);
+ // Creates a new Stack resource.
+ static std::unique_ptr<XlaResource> CreateStack(string name, DataType type,
+ int64 max_size);
+
+ // Creates a new TensorArray resource.
+ static std::unique_ptr<XlaResource> CreateTensorArray(
+ string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
+ int64 max_array_size);
+
XlaResource(Kind kind, int arg_num, string name, DataType type,
TensorShape shape, const xla::XlaOp& initial_value,
- int64 tensor_array_size,
+ int64 max_array_size,
const std::set<string>& tensor_array_gradients,
bool tensor_array_multiple_writes_aggregate);
@@ -119,12 +128,12 @@
// TODO(phawkins): refactor this code to use subclasses, rather than putting
// kind-specific fields in XlaResource.
- // 'tensor_array_size' stores the expected size of the TensorArray or Stack.
+ // 'max_array_size' stores the expected size of the TensorArray or Stack.
// We need to store this since sometimes TensorArrays must be initialized
// lazily since we do not know the element shape at construction time.
// Used by both TensorArrays and Stacks.
- int64 tensor_array_size() const { return tensor_array_size_; }
- void set_tensor_array_size(int64 size) { tensor_array_size_ = size; }
+ int64 max_array_size() const { return max_array_size_; }
+ void set_max_array_size(int64 size) { max_array_size_ = size; }
bool tensor_array_multiple_writes_aggregate() const {
return tensor_array_multiple_writes_aggregate_;
@@ -151,7 +160,7 @@
xla::XlaOp value_;
xla::XlaOp initial_value_;
- int64 tensor_array_size_ = -1;
+ int64 max_array_size_ = -1;
bool tensor_array_multiple_writes_aggregate_ = false;
std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_;
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index d914e97..4360e08 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -226,12 +226,14 @@
"index_util.cc",
"layout_util.cc",
"primitive_util.cc",
+ "shape.cc",
"shape_util.cc",
],
hdrs = [
"index_util.h",
"layout_util.h",
"primitive_util.h",
+ "shape.h",
"shape_util.h",
],
visibility = ["//visibility:public"],
@@ -255,6 +257,23 @@
)
tf_cc_test(
+ name = "shape_test",
+ srcs = ["shape_test.cc"],
+ deps = [
+ ":shape_util",
+ ":status_macros",
+ ":test",
+ ":test_helpers",
+ ":types",
+ ":util",
+ ":xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test_main",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+tf_cc_test(
name = "shape_util_test",
srcs = ["shape_util_test.cc"],
deps = [
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index 42da0eb..ad2e525 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -90,11 +90,12 @@
srcs = ["executable_build_options.cc"],
hdrs = ["executable_build_options.h"],
deps = [
+ "//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/service:device_memory_allocator",
- "//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:optional",
@@ -191,6 +192,7 @@
hdrs = ["xla_computation.h"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc
index 0f17453..1f594e5 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.cc
+++ b/tensorflow/compiler/xla/client/executable_build_options.cc
@@ -16,6 +16,7 @@
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "absl/strings/str_format.h"
+#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace xla {
@@ -39,6 +40,13 @@
int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; }
+DebugOptions* ExecutableBuildOptions::mutable_debug_options() {
+ if (!has_debug_options()) {
+ debug_options_ = GetDebugOptionsFromFlags();
+ }
+ return &debug_options_.value();
+}
+
ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout(
const Shape& shape_with_layout) {
result_layout_set_ = true;
@@ -55,68 +63,10 @@
if (result_layout_set_) {
result_layout = ShapeUtil::HumanStringWithLayout(result_layout_);
}
- string generate_hlo_graph = "nullopt";
- if (generate_hlo_graph_.has_value()) {
- generate_hlo_graph = generate_hlo_graph_.value();
- }
return absl::StrFormat(
"ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, "
"generate_hlo_graph=%s}",
- device_ordinal_, result_layout, generate_hlo_graph);
-}
-
-ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph(
- string regex) {
- generate_hlo_graph_ = std::move(regex);
- return *this;
-}
-
-const absl::optional<string>& ExecutableBuildOptions::generate_hlo_graph()
- const {
- return generate_hlo_graph_;
-}
-
-ExecutableBuildOptions& ExecutableBuildOptions::set_dump_optimized_hlo_proto_to(
- absl::string_view dirpath) {
- dump_optimized_hlo_proto_to_ = string(dirpath);
- return *this;
-}
-
-const absl::optional<string>&
-ExecutableBuildOptions::dump_optimized_hlo_proto_to() const {
- return dump_optimized_hlo_proto_to_;
-}
-
-ExecutableBuildOptions&
-ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to(
- absl::string_view dirpath) {
- dump_unoptimized_hlo_proto_to_ = string(dirpath);
- return *this;
-}
-
-const absl::optional<string>&
-ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const {
- return dump_unoptimized_hlo_proto_to_;
-}
-
-ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to(
- absl::string_view dirpath) {
- dump_per_pass_hlo_proto_to_ = string(dirpath);
- return *this;
-}
-
-const absl::optional<string>&
-ExecutableBuildOptions::dump_per_pass_hlo_proto_to() const {
- return dump_per_pass_hlo_proto_to_;
-}
-
-ExecutableBuildOptions& ExecutableBuildOptions::set_hlo_profile(bool enabled) {
- hlo_profile_ = enabled;
- return *this;
-}
-
-absl::optional<bool> ExecutableBuildOptions::hlo_profile() const {
- return hlo_profile_;
+ device_ordinal_, result_layout, debug_options().xla_generate_hlo_graph());
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h
index 93334db..dd8cb55 100644
--- a/tensorflow/compiler/xla/client/executable_build_options.h
+++ b/tensorflow/compiler/xla/client/executable_build_options.h
@@ -20,6 +20,7 @@
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@@ -44,6 +45,12 @@
ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout);
const Shape* result_layout() const;
+ // Expose access to the XLA debug options which will be passed to the
+ // compilation process.
+ bool has_debug_options() const { return debug_options_.has_value(); }
+ const DebugOptions& debug_options() const { return *debug_options_; }
+ DebugOptions* mutable_debug_options();
+
// If set, this specifies an allocator that can be used to allocate temporary
// space on the device during compilation. For example, the compiler might
// want to run various algorithms on the device and pick the fastest one -- it
@@ -55,56 +62,16 @@
DeviceMemoryAllocator* allocator);
DeviceMemoryAllocator* device_allocator() const;
- // If set, specifies a regexp of HLO graphs to dump (as in DebugOptions).
- ExecutableBuildOptions& set_generate_hlo_graph(string regex);
- const absl::optional<string>& generate_hlo_graph() const;
-
- // If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO
- // protobuf to (as in DebugOptions).
- ExecutableBuildOptions& set_dump_optimized_hlo_proto_to(
- absl::string_view dirpath);
- const absl::optional<string>& dump_optimized_hlo_proto_to() const;
-
- // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO
- // protobuf to (as in DebugOptions).
- ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to(
- absl::string_view dirpath);
- const absl::optional<string>& dump_unoptimized_hlo_proto_to() const;
-
- // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs
- // to (as in DebugOptions).
- ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to(
- absl::string_view dirpath);
- const absl::optional<string>& dump_per_pass_hlo_proto_to() const;
-
- // If true, specifies that we should record an HLO profile during execution
- // and log it after execution (as in DebugOptions). If nullopt the default is
- // used.
- ExecutableBuildOptions& set_hlo_profile(bool enabled);
- absl::optional<bool> hlo_profile() const;
-
- void add_disabled_hlo_pass(absl::string_view pass_name) {
- disabled_hlo_passes_.push_back(std::string(pass_name));
- }
- const absl::Span<const std::string> disabled_hlo_passes() const {
- return disabled_hlo_passes_;
- }
-
// Returns a string representation of the build options, suitable for
// debugging.
string ToString() const;
private:
- absl::optional<bool> hlo_profile_;
int device_ordinal_ = -1;
Shape result_layout_;
bool result_layout_set_ = false;
- absl::optional<string> generate_hlo_graph_;
- absl::optional<string> dump_optimized_hlo_proto_to_;
- absl::optional<string> dump_unoptimized_hlo_proto_to_;
- absl::optional<string> dump_per_pass_hlo_proto_to_;
+ absl::optional<DebugOptions> debug_options_;
DeviceMemoryAllocator* device_allocator_ = nullptr;
- std::vector<std::string> disabled_hlo_passes_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index f833ddc..c5733bc 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -164,7 +164,6 @@
deps = [
":constants",
":math",
- ":numeric",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
@@ -178,8 +177,9 @@
srcs = ["sorting.cc"],
hdrs = ["sorting.h"],
deps = [
- ":numeric",
+ "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
],
@@ -188,10 +188,6 @@
xla_test(
name = "sorting_test",
srcs = ["sorting_test.cc"],
- blacklisted_backends = [
- "cpu",
- "gpu",
- ],
tags = ["enable_for_xla_interpreter"],
deps = [
":sorting",
diff --git a/tensorflow/compiler/xla/client/lib/numeric.h b/tensorflow/compiler/xla/client/lib/numeric.h
index efd8cdc..f62fdab 100644
--- a/tensorflow/compiler/xla/client/lib/numeric.h
+++ b/tensorflow/compiler/xla/client/lib/numeric.h
@@ -22,9 +22,6 @@
namespace xla {
-// Returns a rank 1 tensor of `type` containing values [0, 1, 2, ...].
-XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
-
// Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere
// else.
XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n);
diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc
index c6f68c8..85b9e18 100644
--- a/tensorflow/compiler/xla/client/lib/prng.cc
+++ b/tensorflow/compiler/xla/client/lib/prng.cc
@@ -18,7 +18,6 @@
#include "absl/base/casts.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
-#include "tensorflow/compiler/xla/client/lib/numeric.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/util.h"
diff --git a/tensorflow/compiler/xla/client/lib/sorting.cc b/tensorflow/compiler/xla/client/lib/sorting.cc
index 0475fd9..e8553a0 100644
--- a/tensorflow/compiler/xla/client/lib/sorting.cc
+++ b/tensorflow/compiler/xla/client/lib/sorting.cc
@@ -14,7 +14,9 @@
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/sorting.h"
-#include "tensorflow/compiler/xla/client/lib/numeric.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/util.h"
namespace xla {
@@ -23,13 +25,12 @@
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input));
int last_dim = input_shape.dimensions_size() - 1;
- int last_dim_size = input_shape.dimensions(last_dim);
- XlaOp iota_s32 = Iota(builder, S32, last_dim_size);
+ Shape iota_shape =
+ ShapeUtil::MakeShape(S32, AsInt64Slice(input_shape.dimensions()));
+ XlaOp iota_s32 = Iota(builder, iota_shape, last_dim);
auto input_dims = input_shape.dimensions();
- std::vector<int64> broadcast_dims(input_dims.begin(), input_dims.end() - 1);
- XlaOp broadcast_s32 = Broadcast(iota_s32, broadcast_dims);
- XlaOp sort_result = Sort(Neg(input), {broadcast_s32});
+ XlaOp sort_result = Sort(Neg(input), {iota_s32});
std::vector<int64> start_indices(input_shape.dimensions_size(), 0);
std::vector<int64> limit_indices(input_dims.begin(), input_dims.end());
limit_indices[last_dim] = k;
diff --git a/tensorflow/compiler/xla/client/lib/sorting_test.cc b/tensorflow/compiler/xla/client/lib/sorting_test.cc
index fef98c9..27ff36c 100644
--- a/tensorflow/compiler/xla/client/lib/sorting_test.cc
+++ b/tensorflow/compiler/xla/client/lib/sorting_test.cc
@@ -14,6 +14,9 @@
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/sorting.h"
+
+#include <limits>
+
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -41,6 +44,28 @@
ComputeAndCompareR1<int>(&builder, {0, 1, 2}, {});
}
+// TODO(b/119930279): enable this test.
+XLA_TEST_F(SortingTest, DISABLED_TopKFullSortMinInt) {
+ XlaBuilder builder(TestName());
+ auto x_rev = ConstantR1<int>(&builder, {std::numeric_limits<int>::min(),
+ std::numeric_limits<int>::min() + 1,
+ std::numeric_limits<int>::max()});
+ xla::GetTupleElement(xla::TopK(x_rev, 3), 1);
+ ComputeAndCompareR1<int>(&builder, {2, 1, 0}, {});
+}
+
+XLA_TEST_F(SortingTest, NOT_TopKFullSortMinInt) {
+ XlaBuilder builder(TestName());
+ auto x_rev = ConstantR1<int>(&builder, {std::numeric_limits<int>::min(),
+ std::numeric_limits<int>::min() + 1,
+ std::numeric_limits<int>::max()});
+ xla::GetTupleElement(xla::TopK(x_rev, 3), 1);
+ // TopK currently negates the keys, which doesn't work correctly for
+ // std::numeric_limits<int>::min(). Therefore, it will sort this key to the
+ // front instead of to the back.
+ ComputeAndCompareR1<int>(&builder, {0, 2, 1}, {});
+}
+
XLA_TEST_F(SortingTest, TopKFullSort) {
XlaBuilder builder(TestName());
const int kSize = 16;
@@ -56,5 +81,13 @@
ComputeAndCompareR1<float>(&builder, inputs, {});
}
+XLA_TEST_F(SortingTest, TopKFullSortWithDuplicates) {
+ XlaBuilder builder(TestName());
+ XlaOp a;
+ auto a_data = CreateR1Parameter<int>({1, 1, 2, 2, 1}, 0, "a", &builder, &a);
+ xla::GetTupleElement(xla::TopK(a, 5), 1);
+ ComputeAndCompareR1<int>(&builder, {2, 3, 0, 1, 4}, {a_data.get()});
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index f508ffb..f17bc45 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -288,7 +288,8 @@
HloComputationProto entry;
SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
- TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id));
+ TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
+ *entry.mutable_program_shape() = program_shape.ToProto();
entry.set_root_id(root_id);
for (auto& instruction : instructions_) {
@@ -1319,6 +1320,15 @@
if (tokens.empty()) {
return InvalidArgument("AfterAll requires at least one operand");
}
+ for (int i = 0; i < tokens.size(); ++i) {
+ const XlaOp& operand = tokens[i];
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ if (!ShapeUtil::IsToken(operand_shape)) {
+ return InvalidArgument(
+ "All operands to AfterAll must be tokens; operand %d has shape %s",
+ i, ShapeUtil::HumanString(operand_shape));
+ }
+ }
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeTokenShape();
return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
@@ -2372,7 +2382,7 @@
SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
GetNextId());
entry.set_root_id(root->id());
- ProgramShape* program_shape = entry.mutable_program_shape();
+ ProgramShapeProto* program_shape = entry.mutable_program_shape();
*program_shape->mutable_result() = root->shape();
// We use std::set to keep the instruction ids in ascending order (which is
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index 8aa85c3..e534fb6 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -446,5 +446,14 @@
EXPECT_EQ(c0_string, c1_string);
}
+TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) {
+ XlaBuilder b(TestName());
+ AfterAll(&b, {CreateToken(&b), ConstantR0<float>(&b, 1.0)});
+ Status status = b.Build().status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("All operands to AfterAll must be tokens"));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc
index c9870b6..f317892 100644
--- a/tensorflow/compiler/xla/client/xla_computation.cc
+++ b/tensorflow/compiler/xla/client/xla_computation.cc
@@ -25,7 +25,7 @@
StatusOr<ProgramShape> XlaComputation::GetProgramShape() const {
TF_RET_CHECK(proto_.has_host_program_shape());
- return proto_.host_program_shape();
+ return ProgramShape(proto_.host_program_shape());
}
StatusOr<std::unique_ptr<HloSnapshot>> XlaComputation::Snapshot() const {
diff --git a/tensorflow/compiler/xla/client/xla_computation.h b/tensorflow/compiler/xla/client/xla_computation.h
index 71598ef..3ccbfb2 100644
--- a/tensorflow/compiler/xla/client/xla_computation.h
+++ b/tensorflow/compiler/xla/client/xla_computation.h
@@ -19,6 +19,7 @@
#include <utility>
#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md
index 73a9db7..bc87a60 100644
--- a/tensorflow/compiler/xla/g3doc/operation_semantics.md
+++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md
@@ -13,6 +13,22 @@
and familiar names; for example a *vector* is a 1-dimensional array and a
*matrix* is a 2-dimensional array.
+## AfterAll
+
+See also
+[`XlaBuilder::AfterAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+AfterAll takes a variadic number of tokens and produces a single token. Tokens
+are primitive types which can be threaded between side-effecting operations to
+enforce ordering. `AfterAll` can be used as a join of tokens for ordering a
+operation after a set operations.
+
+<b> `AfterAll(operands)` </b>
+
+Arguments | Type | Semantics
+---------- | ------- | -------------------------
+`operands` | `XlaOp` | variadic number of tokens
+
## AllToAll
See also
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
index 6e03907..6c298e5 100644
--- a/tensorflow/compiler/xla/layout_util.h
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -21,6 +21,7 @@
#include <string>
#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index fcc59f6..f2fcb93 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -63,6 +63,14 @@
}
}
+// Since Eigen::half doesn't satisfy the absl::bit_cast contract, we need to be
+// able to transparently access the raw 16-bit value contained within.
+template <typename T>
+T GetRawValue(T val) {
+ return val;
+}
+uint16 GetRawValue(Eigen::half val) { return val.x; }
+
} // namespace
LiteralBase::~LiteralBase() {}
@@ -1123,7 +1131,6 @@
}
}
pieces->push_back(brace_to_string("}"));
- return;
}
};
@@ -1207,16 +1214,32 @@
}
template <typename NativeSrcT, typename NativeDestT>
-typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)),
+typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT) &&
+ !std::is_same<NativeDestT, Eigen::half>::value),
Literal>::type
BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
auto converter = [](NativeSrcT src) {
- return absl::bit_cast<NativeDestT>(src);
+ return absl::bit_cast<NativeDestT>(GetRawValue(src));
};
return ConvertBetweenNativeTypesWithConverter<NativeSrcT, NativeDestT>(
src_literal, converter);
}
+template <typename NativeSrcT, typename NativeDestT>
+typename std::enable_if<(sizeof(NativeSrcT) == sizeof(Eigen::half) &&
+ std::is_same<NativeDestT, Eigen::half>::value),
+ Literal>::type
+BitcastBetweenNativeTypes(const LiteralBase& src_literal) {
+ // Eigen::half doesn't satisfy the absl::bit_cast contract, so explicitly
+ // cast to unsigned short and then use raw_uint16_to_half.
+ auto converter = [](NativeSrcT src) {
+ return Eigen::half_impl::raw_uint16_to_half(
+ absl::bit_cast<uint16>(GetRawValue(src)));
+ };
+ return ConvertBetweenNativeTypesWithConverter<NativeSrcT, Eigen::half>(
+ src_literal, converter);
+}
+
// This template specialization is here to make the compiler happy. bit_cast has
// a static check that the types are the same size. This specialization should
// never be used because the source and destination types are checked for
diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc
index b507a2e..ac342bf 100644
--- a/tensorflow/compiler/xla/protobuf_util.cc
+++ b/tensorflow/compiler/xla/protobuf_util.cc
@@ -40,16 +40,6 @@
namespace {
-string SanitizeFilename(const string& file_name) {
- string safe_file_name = file_name;
- for (char& c : safe_file_name) {
- if (c == '/' || c == '\\') {
- c = '_';
- }
- }
- return safe_file_name;
-}
-
std::pair<tensorflow::mutex*, std::vector<std::function<string(string)>>*>
GetDirectoryExpanders() {
static auto* mutex = new tensorflow::mutex;
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 4d2a37c..2768ed6 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -487,12 +487,13 @@
xrt::XLAComputation c;
auto config = c.mutable_config();
- auto shapes = config->mutable_program_shape();
+ ProgramShape shapes;
for (auto& shape : argument_shapes) {
- *shapes->add_parameters() = shape;
+ *shapes.add_parameters() = shape;
}
- TF_ASSIGN_OR_RETURN(*shapes->mutable_result(), GetReturnValueShape());
- LayoutUtil::SetToDefaultLayout(shapes);
+ TF_ASSIGN_OR_RETURN(*shapes.mutable_result(), GetReturnValueShape());
+ LayoutUtil::SetToDefaultLayout(&shapes);
+ *config->mutable_program_shape() = shapes.ToProto();
auto snapshot = computation().Snapshot().ValueOrDie();
*c.mutable_hlo_snapshot() = *snapshot;
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index feabfdb..5c2538d 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -921,22 +921,22 @@
$1 = NULL;
} else {
if (!HandleStringAttribute($input, "generate_hlo_graph", [&](string s) {
- build_options.set_generate_hlo_graph(std::move(s));
+ build_options.mutable_debug_options()->set_xla_generate_hlo_graph(std::move(s));
})) {
return nullptr;
}
if (!HandleStringAttribute($input, "dump_optimized_hlo_proto_to", [&](string s) {
- build_options.set_dump_optimized_hlo_proto_to(std::move(s));
+ build_options.mutable_debug_options()->set_xla_dump_optimized_hlo_proto_to(std::move(s));
})) {
return nullptr;
}
if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) {
- build_options.set_dump_unoptimized_hlo_proto_to(std::move(s));
+ build_options.mutable_debug_options()->set_xla_dump_unoptimized_hlo_proto_to(std::move(s));
})) {
return nullptr;
}
if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) {
- build_options.set_dump_per_pass_hlo_proto_to(std::move(s));
+ build_options.mutable_debug_options()->set_xla_dump_per_pass_hlo_proto_to(std::move(s));
})) {
return nullptr;
}
@@ -950,7 +950,7 @@
PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.hlo_profile must be a bool or None.");
SWIG_fail;
}
- build_options.set_hlo_profile(o == Py_True);
+ build_options.mutable_debug_options()->set_xla_hlo_profile(o == Py_True);
}
Py_DECREF(o);
@@ -992,6 +992,7 @@
%unignore xla::swig::XrtAllocation;
%unignore xla::swig::XrtAllocation::FromLiteral;
%unignore xla::swig::XrtAllocation::ToLiteral;
+%unignore xla::swig::XrtAllocation::shape;
%unignore xla::swig::XrtAllocationTuple;
%unignore xla::swig::XrtAllocationTuple::Release;
%unignore xla::swig::XrtAllocationTuple::size;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 92b0685..5994e55 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -26,6 +26,9 @@
import numpy as np
+import six
+from six.moves import xrange
+
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.compiler.xla.python import pywrap_xla as c_api
from tensorflow.compiler.xla.service import hlo_pb2
@@ -75,6 +78,13 @@
source_line=lineno)
+def _maybe_encode_string(s):
+ if six.PY3:
+ return s.encode('utf-8')
+ else:
+ return s
+
+
class PaddingType(enum.Enum):
VALID = 1
SAME = 2
@@ -225,7 +235,8 @@
"""Allocate and copy to XLA the given python value."""
pyval = require_numpy_array_layout(pyval)
if backend.backend_type == BackendType.XRT:
- cbuf = c_api.XrtAllocation.FromLiteral(pyval, backend.target)
+ cbuf = c_api.XrtAllocation.FromLiteral(
+ pyval, _maybe_encode_string(backend.target))
else:
cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None)
return LocalBuffer(cbuf, backend)
@@ -245,8 +256,8 @@
"""Assuming a tuple buffer, unpack it into constituent tuple elements."""
assert self.c_buffer is not None
if self._backend.backend_type == BackendType.XRT:
- result = c_api.DestructureXrtAllocationTuple(self.c_buffer,
- self._backend.target)
+ result = c_api.DestructureXrtAllocationTuple(
+ self.c_buffer, _maybe_encode_string(self._backend.target))
else:
result = c_api.DestructureLocalShapedBufferTuple(self.c_buffer)
self.delete()
@@ -322,6 +333,9 @@
def __ne__(self, other):
return not self == other
+ def __hash__(self):
+ return hash((self._dtype, self._dimensions, self._minor_to_major))
+
def __repr__(self):
return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, '
'_is_tuple={!r}, _minor_to_major={!r})').format(
@@ -541,10 +555,13 @@
]
result_shape = result_shape.map_leaves(layout_fn)
+ argument_shapes = list(argument_shapes)
+
compile_options = compile_options or CompileOptions()
compile_options.result_shape = result_shape
if self._backend.backend_type == BackendType.XRT:
- c = self.computation.CompileForXrt(argument_shapes, self._backend.target)
+ c = self.computation.CompileForXrt(
+ argument_shapes, _maybe_encode_string(self._backend.target))
else:
c = self.computation.Compile(argument_shapes, compile_options)
return LocalComputation(c, is_compiled=True, backend=self._backend)
@@ -1380,6 +1397,7 @@
Raises:
A runtime exception if the XLA service has already been initialized.
"""
+ platform_name = _maybe_encode_string(platform_name)
c_api.InitializePlatformName(platform_name)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 56bf3a9..a348bcf 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -96,6 +96,11 @@
valid_bitcast_callback(operand->shape(), instr->shape());
}
+bool IsUnstridedSlice(const HloInstruction* hlo) {
+ return absl::c_all_of(hlo->slice_strides(),
+ [](int64 stride) { return stride == 1; });
+}
+
// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
// algebraic expressions to simplified forms. Note: This only supports
// simplifications that simply look at the operands of an instruction. For the
@@ -520,7 +525,74 @@
VLOG(10) << "trying to replace " << concatenate->ToString() << " with "
<< replacement->ToString();
ReplaceInstructionIfSameShape(concatenate, replacement);
- } else if (operands.size() == 2) {
+ return Status::OK();
+ }
+
+ // Check if we can merge "adjacent" slice operands which take slices from the
+ // same other op. For simplicity we only merge unstrided slices.
+ int64 concatenate_dimension = concatenate->concatenate_dimension();
+ for (int64 i = 0; i < operands.size(); ++i) {
+ if (operands[i]->opcode() != HloOpcode::kSlice ||
+ !IsUnstridedSlice(operands[i])) {
+ continue;
+ }
+ int64 slice_end = operands[i]->slice_limits(concatenate_dimension);
+ HloInstruction* slice_operand = operands[i]->mutable_operand(0);
+ int64 j = i + 1;
+ while (j < operands.size() && operands[j]->opcode() == HloOpcode::kSlice &&
+ IsUnstridedSlice(operands[j]) &&
+ operands[j]->operand(0) == slice_operand &&
+ operands[j]->slice_starts(concatenate_dimension) == slice_end) {
+ // Check that all the slice_start values are the same in all other
+ // dimensions. This implies that the slice_limit values are also the same,
+ // because operands of concatenate need to have the same shape, and we
+ // already checked that the slices are unstrided.
+ bool same_other_starts = true;
+ for (int64 k = 0; k < operands[j]->slice_starts().size(); ++k) {
+ if (k == concatenate_dimension) {
+ continue;
+ }
+ if (operands[i]->slice_starts(k) != operands[j]->slice_starts(k)) {
+ same_other_starts = false;
+ break;
+ }
+ }
+ if (!same_other_starts) {
+ break;
+ }
+ slice_end = operands[j]->slice_limits(concatenate_dimension);
+ ++j;
+ }
+ if (j - i > 1) {
+ Shape new_slice_shape = operands[i]->shape();
+ new_slice_shape.set_dimensions(
+ concatenate_dimension,
+ slice_end - operands[i]->slice_starts(concatenate_dimension));
+ auto new_limit_indices = operands[i]->slice_limits();
+ new_limit_indices[concatenate_dimension] = slice_end;
+ auto new_slice_op =
+ computation_->AddInstruction(HloInstruction::CreateSlice(
+ new_slice_shape, slice_operand,
+ /*start_indices=*/operands[i]->slice_starts(),
+ /*limit_indices=*/new_limit_indices,
+ /*strides=*/operands[i]->slice_strides()));
+ std::vector<HloInstruction*> new_operands;
+ for (int64 k = 0; k < i; ++k) {
+ new_operands.push_back(operands[k]);
+ }
+ new_operands.push_back(new_slice_op);
+ for (int64 k = j; k < operands.size(); ++k) {
+ new_operands.push_back(operands[k]);
+ }
+ auto replacement =
+ computation_->AddInstruction(concatenate->CloneWithNewOperands(
+ concatenate->shape(), new_operands));
+ ReplaceInstructionIfSameShape(concatenate, replacement);
+ return Status::OK();
+ }
+ }
+
+ if (operands.size() == 2) {
// A binary concat with a broadcasted scalar as an operand can be converted
// into a pad which is simpler to fold into other operations.
bool is_effective_low_pad = Match(
@@ -536,7 +608,7 @@
padding_config_dim->set_edge_padding_high(0);
padding_config_dim->set_edge_padding_low(0);
padding_config_dim->set_interior_padding(0);
- if (dim == concatenate->concatenate_dimension()) {
+ if (dim == concatenate_dimension) {
if (is_effective_low_pad) {
padding_config_dim->set_edge_padding_low(
operands[0]->shape().dimensions(dim));
@@ -1599,6 +1671,27 @@
pad, HloInstruction::CreateBroadcast(pad->shape(),
pad->mutable_operand(1), {}));
}
+
+ // Interior padding on one sized dimensions have no effect. As a result it
+ // makes other simplifications possible if there is no interior padding.
+ if (HasInteriorPadding(pad->padding_config())) {
+ PaddingConfig padding_config = pad->padding_config();
+ bool cleared_interior_padding = false;
+ for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
+ if (padding_config.dimensions(i).interior_padding() > 0 &&
+ pad->operand(0)->shape().dimensions(i) == 1) {
+ cleared_interior_padding = true;
+ padding_config.mutable_dimensions(i)->set_interior_padding(0);
+ }
+ }
+ if (cleared_interior_padding) {
+ return ReplaceWithNewInstruction(
+ pad,
+ HloInstruction::CreatePad(pad->shape(), pad->mutable_operand(0),
+ pad->mutable_operand(1), padding_config));
+ }
+ }
+
// Eliminate nop pads (padding all zero), and replace a pad with negative
// padding with a pad with non-negative padding followed by a slice.
bool all_zero = true;
@@ -2010,11 +2103,6 @@
return false;
}
-bool IsUnstridedSlice(const HloInstruction* hlo) {
- return absl::c_all_of(hlo->slice_strides(),
- [](int64 stride) { return stride == 1; });
-}
-
StatusOr<bool> AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape(
HloInstruction* slice) {
CHECK_EQ(slice->opcode(), HloOpcode::kSlice);
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 8b8ba2a..48f689c 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1437,6 +1437,76 @@
EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1));
}
+TEST_F(AlgebraicSimplifierTest, SimplifyConcatenateOfSlices) {
+ auto m = CreateNewVerifiedModule();
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {100, 99});
+ Shape concat_shape = ShapeUtil::MakeShape(F32, {50, 80});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r2f32, "param0"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r2f32, "param1"));
+
+ HloInstruction* slice0 = builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{0, 0},
+ /*limit_indices=*/{50, 10}, /*strides=*/{1, 1}));
+
+ // Cannot merge 'slice0' and 'slice1' because of different start indices in
+ // dimension 0.
+ HloInstruction* slice1 = builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 10},
+ /*limit_indices=*/{100, 20}, /*strides=*/{1, 1}));
+
+ // Cannot merge 'slice1' and 'slice2' because of stride in dimension 2.
+ HloInstruction* slice2 = builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 20},
+ /*limit_indices=*/{100, 40}, /*strides=*/{1, 2}));
+
+ // Cannot merge 'slice2' and 'slice3' because of stride in dimension 2.
+ HloInstruction* slice3 = builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 40},
+ /*limit_indices=*/{100, 50}, /*strides=*/{1, 1}));
+
+ // Can merge 'slice3' and 'slice4'.
+ HloInstruction* slice4 = builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 50},
+ /*limit_indices=*/{100, 60}, /*strides=*/{1, 1}));
+
+ // Can merge 'slice4' and 'slice5'.
+ HloInstruction* slice5 = builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 60},
+ /*limit_indices=*/{100, 70}, /*strides=*/{1, 1}));
+
+ // Cannot merge 'slice5' and 'slice6' because of overlap.
+ HloInstruction* slice6 = builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {50, 10}), param0, /*start_indices=*/{50, 69},
+ /*limit_indices=*/{100, 79}, /*strides=*/{1, 1}));
+
+ // Cannot merge 'slice6' and 'slice7' because of slicing from a different
+ // parameter.
+ HloInstruction* slice7 = builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {50, 10}), param1, /*start_indices=*/{50, 79},
+ /*limit_indices=*/{100, 89}, /*strides=*/{1, 1}));
+
+ builder.AddInstruction(HloInstruction::CreateConcatenate(
+ concat_shape,
+ {slice0, slice1, slice2, slice3, slice4, slice5, slice6, slice7}, 1));
+ auto computation = m->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(default_options_);
+ ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie());
+ EXPECT_THAT(
+ computation->root_instruction(),
+ op::Concatenate(op::Slice(param0), op::Slice(param0), op::Slice(param0),
+ op::Slice(param0), op::Slice(param0), op::Slice(param1)));
+ // The operand 3 should be a merge of 'slice3', 'slice4' and 'slice5', so its
+ // shape should have dimensions {50, 30}.
+ EXPECT_TRUE(
+ ShapeUtil::Equal(computation->root_instruction()->operand(3)->shape(),
+ ShapeUtil::MakeShape(F32, {50, 30})));
+ EXPECT_EQ(computation->root_instruction()->operand(3)->slice_starts(1), 40);
+}
+
// Test that a simplification which changes layouts is not performed if layout
// sensitive is true.
TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
@@ -2119,6 +2189,40 @@
has_negative_padding(computation->root_instruction()->operand(0)));
}
+TEST_F(AlgebraicSimplifierTest, TrivialInteriorPadding) {
+ // Verify that a pad instruction with interior padding on one-sized
+ // dimensions, removes the interior padding.
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 1}), "param"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ PaddingConfig padding;
+ for (int i = 0; i < 2; ++i) {
+ auto dimension = padding.add_dimensions();
+ dimension->set_edge_padding_low(3);
+ dimension->set_edge_padding_high(3);
+ dimension->set_interior_padding(i * 3);
+ }
+ HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeUtil::MakeShape(F32, {8, 7}), param, zero, padding));
+
+ auto module = CreateNewVerifiedModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(default_options_);
+
+ ASSERT_THAT(computation->root_instruction(), op::Pad(param, zero));
+ ASSERT_TRUE(HasInteriorPadding(pad->padding_config()));
+
+ EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
+ EXPECT_FALSE(
+ HasInteriorPadding(computation->root_instruction()->padding_config()));
+}
+
TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
HloComputation::Builder builder(TestName());
HloInstruction* param =
diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index 6713227..0237f16 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -86,15 +86,15 @@
Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot));
}
- const auto& program_shape = instance.computation.host_program_shape();
ExecutionOptions execution_options;
*execution_options.mutable_debug_options() = debug_options;
*execution_options.mutable_shape_with_output_layout() =
*instance.result_layout;
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(program_shape, instance.argument_layouts,
- &execution_options));
+ CreateModuleConfig(
+ ProgramShape(instance.computation.host_program_shape()),
+ instance.argument_layouts, &execution_options));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModule> hlo_module,
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
index 7f7f150..10c53f1 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc
@@ -142,16 +142,16 @@
// Finally we use the Eq op of these two broadcasted constants and get the
// desired mask.
HloInstruction* GetExpandedFilterMask(
- const Shape& filter_shape, int64 input_feature_dim,
- int64 output_feature_dim, int64 group_count,
+ const Shape& filter_shape, int64 kernel_input_feature_dim,
+ int64 kernel_output_feature_dim, int64 group_count,
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
add_instruction) {
Shape expanded_filter_shape =
- ExpandedFilterShape(filter_shape, group_count, input_feature_dim);
+ ExpandedFilterShape(filter_shape, group_count, kernel_input_feature_dim);
Shape mask_shape = ShapeUtil::MakeShape(
S32, AsInt64Slice(expanded_filter_shape.dimensions()));
- int64 output_feature = filter_shape.dimensions(output_feature_dim);
- int64 group_size = filter_shape.dimensions(input_feature_dim);
+ int64 output_feature = filter_shape.dimensions(kernel_output_feature_dim);
+ int64 group_size = filter_shape.dimensions(kernel_input_feature_dim);
// Create a 'input_feature' sized linspace and 'output_feature' sized linspace
// that will be broadcasted into perpendicular dimensions and compared.
@@ -159,15 +159,14 @@
GetMaskIds(group_size, group_count);
const std::vector<int32> output_feature_filter_mask =
GetMaskIds(output_feature / group_count, group_count);
-
auto mask1 = add_instruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<int32>(input_feature_filter_mask)));
- auto broadcasted_mask1 = add_instruction(
- HloInstruction::CreateBroadcast(mask_shape, mask1, {input_feature_dim}));
+ auto broadcasted_mask1 = add_instruction(HloInstruction::CreateBroadcast(
+ mask_shape, mask1, {kernel_input_feature_dim}));
auto mask2 = add_instruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<int32>(output_feature_filter_mask)));
- auto broadcasted_mask2 = add_instruction(
- HloInstruction::CreateBroadcast(mask_shape, mask2, {output_feature_dim}));
+ auto broadcasted_mask2 = add_instruction(HloInstruction::CreateBroadcast(
+ mask_shape, mask2, {kernel_output_feature_dim}));
// Compare the broadcasted output feature linspace to the input feature
// linspace to create a diagonal predicate.
@@ -189,18 +188,20 @@
};
auto dim_numbers = convolution->convolution_dimension_numbers();
- int64 input_feature_dim = dim_numbers.kernel_input_feature_dimension();
- int64 group_size = filter->shape().dimensions(input_feature_dim);
- int64 output_feature_dim = dim_numbers.kernel_output_feature_dimension();
- auto expanded_filter_shape =
- ExpandedFilterShape(filter->shape(), group_count, input_feature_dim);
- HloInstruction* filter_mask = GetExpandedFilterMask(
- filter->shape(), input_feature_dim, output_feature_dim, group_count, add);
+ int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension();
+ int64 group_size = filter->shape().dimensions(kernel_input_feature_dim);
+ int64 kernel_output_feature_dim =
+ dim_numbers.kernel_output_feature_dimension();
+ auto expanded_filter_shape = ExpandedFilterShape(filter->shape(), group_count,
+ kernel_input_feature_dim);
+ HloInstruction* filter_mask =
+ GetExpandedFilterMask(filter->shape(), kernel_input_feature_dim,
+ kernel_output_feature_dim, group_count, add);
HloInstruction* expanded_filter;
if (group_size == 1) {
bool depthwise_separable =
- (group_count == filter->shape().dimensions(output_feature_dim));
+ (group_count == filter->shape().dimensions(kernel_output_feature_dim));
// If the code generator handles depthwise separable convolutions
// inherently, then no filter expansion is needed.
if (!filter_expansion_ && depthwise_separable) {
@@ -241,39 +242,108 @@
// We want to repeat 'filter' in the 'input_feature_dim' dimension
// 'group_count' times.
Shape reshaped_filter_shape =
- ShapeUtil::DeleteDimension(input_feature_dim, filter->shape());
+ ShapeUtil::DeleteDimension(kernel_input_feature_dim, filter->shape());
auto reshaped_filter =
add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
std::vector<int64> broadcast_dims;
for (int64 i = 0; i < filter->shape().dimensions_size(); ++i) {
- if (i == input_feature_dim) {
+ if (i == kernel_input_feature_dim) {
continue;
}
broadcast_dims.push_back(i);
}
expanded_filter = add(HloInstruction::CreateBroadcast(
expanded_filter_shape, reshaped_filter, broadcast_dims));
+
+ auto zero = add(HloInstruction::CreateConstant(
+ LiteralUtil::Zero(expanded_filter_shape.element_type())));
+ auto zero_filter =
+ add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
+ auto new_filter = add(HloInstruction::CreateTernary(
+ expanded_filter_shape, HloOpcode::kSelect, filter_mask, expanded_filter,
+ zero_filter));
+
+ auto new_convolution = HloInstruction::CreateConvolve(
+ convolution->shape(), convolution->mutable_operand(0), new_filter,
+ /*feature_group_count=*/1, convolution->window(), dim_numbers,
+ convolution->precision_config());
+ TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
+ convolution, std::move(new_convolution)));
} else {
- // We could possibly also use reshape, broadcast, reshape instead of concat
- // here, but it would require more complex code, and for depthwise
- // convolution we would never end up in this branch.
- std::vector<HloInstruction*> concat_operands(group_count, filter);
- expanded_filter = add(HloInstruction::CreateConcatenate(
- expanded_filter_shape, concat_operands, input_feature_dim));
+ // The filter expansion mechanism adds zeroes in the kernel.
+ // For an OF = 12, IF = 6, and kernel IF = 2, the expanded filter mask
+ // would look like (IF on the Y-axis, OF on the X-axis)
+ // 1 1 1 1 0 0 0 0 0 0 0 0
+ // 1 1 1 1 0 0 0 0 0 0 0 0
+ // 0 0 0 0 1 1 1 1 0 0 0 0
+ // 0 0 0 0 1 1 1 1 0 0 0 0
+ // 0 0 0 0 0 0 0 0 1 1 1 1
+ // 0 0 0 0 0 0 0 0 1 1 1 1
+ //
+ // Instead of convolving the above with the input, we instead slice the
+ // kernel into three kernels, each containing islands of 1s from the filter
+ // above. We also slice the activations in the IF dimension with each slice
+ // of size = group_size. For each slice, we perform convolutions, and
+ // concatenate the generated outputs in the output OF dimension.
+
+ std::vector<HloInstruction*> sliced_convolutions;
+ auto activation = convolution->mutable_operand(0);
+ std::vector<int64> slice_strides(filter->shape().dimensions_size(), 1);
+ std::vector<int64> filter_slice_starts(filter->shape().dimensions_size(),
+ 0);
+ std::vector<int64> filter_slice_limits(filter->shape().dimensions().begin(),
+ filter->shape().dimensions().end());
+ std::vector<int64> activation_slice_starts(
+ activation->shape().dimensions_size(), 0);
+ std::vector<int64> activation_slice_limits(
+ activation->shape().dimensions().begin(),
+ activation->shape().dimensions().end());
+
+ int64 output_feature =
+ filter->shape().dimensions(kernel_output_feature_dim);
+ auto output_feature_dim = dim_numbers.output_feature_dimension();
+ int64 filter_slice_width = output_feature / group_count;
+
+ int64 activation_input_feature_dim = dim_numbers.input_feature_dimension();
+
+ for (int64 i = 0; i < group_count; i++) {
+ filter_slice_starts[kernel_output_feature_dim] = i * filter_slice_width;
+ filter_slice_limits[kernel_output_feature_dim] =
+ (i + 1) * filter_slice_width;
+ auto filter_sliced_shape = filter->shape();
+ filter_sliced_shape.set_dimensions(kernel_output_feature_dim,
+ filter_slice_width);
+ auto filter_slice = add(HloInstruction::CreateSlice(
+ filter_sliced_shape, filter, filter_slice_starts, filter_slice_limits,
+ slice_strides));
+
+ activation_slice_starts[activation_input_feature_dim] = i * group_size;
+ activation_slice_limits[activation_input_feature_dim] =
+ (i + 1) * group_size;
+ auto activation_sliced_shape = activation->shape();
+ activation_sliced_shape.set_dimensions(activation_input_feature_dim,
+ group_size);
+ auto activation_slice = add(HloInstruction::CreateSlice(
+ activation_sliced_shape, activation, activation_slice_starts,
+ activation_slice_limits, slice_strides));
+
+ auto conv_slice_shape = convolution->shape();
+ conv_slice_shape.set_dimensions(output_feature_dim, filter_slice_width);
+
+ auto new_convolution = add(HloInstruction::CreateConvolve(
+ conv_slice_shape, activation_slice, filter_slice,
+ /*feature_group_count=*/1, convolution->window(), dim_numbers,
+ convolution->precision_config()));
+
+ sliced_convolutions.push_back(new_convolution);
+ }
+
+ auto new_conv = HloInstruction::CreateConcatenate(
+ convolution->shape(), sliced_convolutions, output_feature_dim);
+ TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
+ convolution, std::move(new_conv)));
}
- auto zero = add(HloInstruction::CreateConstant(
- LiteralUtil::Zero(expanded_filter_shape.element_type())));
- auto zero_filter =
- add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
- auto new_filter = add(
- HloInstruction::CreateTernary(expanded_filter_shape, HloOpcode::kSelect,
- filter_mask, expanded_filter, zero_filter));
- auto new_convolution = HloInstruction::CreateConvolve(
- convolution->shape(), convolution->mutable_operand(0), new_filter,
- /*feature_group_count=*/1, convolution->window(), dim_numbers,
- convolution->precision_config());
- TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
- convolution, std::move(new_convolution)));
+
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc
index 28373eb..e6bf214 100644
--- a/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc
+++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter_test.cc
@@ -82,18 +82,14 @@
ConvolutionFeatureGroupConverter converter;
ASSERT_TRUE(converter.Run(module.get()).ValueOrDie());
root = computation->root_instruction();
- // Make sure the convolution is converted to one with feature_group_count = 1.
- EXPECT_EQ(root->opcode(), HloOpcode::kConvolution);
- EXPECT_EQ(root->feature_group_count(), 1);
- // Verify that the filter operand has been replaced.
- EXPECT_THAT(root->operand(1),
- op::Select(op::Eq(op::Broadcast(op::Constant()),
- op::Broadcast(op::Constant())),
- // We expect to see Concatenate here instead of
- // Broadcast, because feature_group_count < input
- // feature dimension.
- op::Concatenate(op::Parameter(), op::Parameter()),
- op::Broadcast(op::Constant())));
+ // Make sure the convolution is replaced with a concatenate.
+ EXPECT_EQ(root->opcode(), HloOpcode::kConcatenate);
+ // And the operands of the concatenate are convolutions, each with a feature
+ // group count = 1.
+ EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kConvolution);
+ EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kConvolution);
+ EXPECT_EQ(root->operand(0)->feature_group_count(), 1);
+ EXPECT_EQ(root->operand(1)->feature_group_count(), 1);
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
index 2852fc8..796a7cf 100644
--- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
@@ -61,17 +61,6 @@
// TODO(b/64227304) Creating a custom pass pipeline will replace this.
namespace {
-
-// TODO(sanjoy): remove this class.
-class FilteredFunctionPassManager : public llvm::legacy::FunctionPassManager {
- public:
- explicit FilteredFunctionPassManager(llvm::Module* m)
- : llvm::legacy::FunctionPassManager(m) {}
- void add(llvm::Pass* p) override {
- llvm::legacy::FunctionPassManager::add(p);
- }
-};
-
class FilteredPassManager : public llvm::legacy::PassManager {
public:
explicit FilteredPassManager(bool disable_expensive_passes)
@@ -94,7 +83,7 @@
std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()(
llvm::Module& module) const {
FilteredPassManager module_passes(disable_expensive_passes_);
- FilteredFunctionPassManager function_passes(&module);
+ llvm::legacy::FunctionPassManager function_passes(&module);
VLOG(2) << "IR before optimizations";
XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
index f9cd61b..6f79ad7 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
@@ -48,10 +48,15 @@
(hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1);
}
+bool HasExactlyOneUse(const HloInstruction& hlo_instr) {
+ return hlo_instr.user_count() == 1 &&
+ absl::c_count(hlo_instr.users().front()->operands(), &hlo_instr) == 1;
+}
+
bool CanBeOutputFused(const HloInstruction* producer,
const HloInstruction* consumer) {
return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) &&
- producer->user_count() == 1;
+ HasExactlyOneUse(*producer) == 1;
}
bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index c77d598..527df0b 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -763,6 +763,28 @@
Not(op::Fusion()));
}
+TEST_F(InstructionFusionTest,
+ DotOperationFusion_DontOutputFuseDuplicateOperands) {
+ absl::string_view module_string = R"(
+HloModule module
+
+ENTRY main {
+ a = f32[50,60]{1,0} parameter(0)
+ b = f32[60,1]{1,0} parameter(1)
+ c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT d = f32[50,1]{1,0} add(c, c)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(module_string));
+ TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
+ CpuInstructionFusion().Run(module.get()));
+ EXPECT_FALSE(fused_something);
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ Not(op::Fusion()));
+}
+
struct GatherLoopFusionTestSpec {
string test_name;
string hlo_computation_text;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index cf97a8b..4032c2d 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -2565,10 +2565,17 @@
return Status::OK();
}
-Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) {
- TF_RET_CHECK(ByteSizeOf(gen_token->shape()) == 0);
+Status IrEmitter::HandleAfterAll(HloInstruction* after_all) {
+ TF_RET_CHECK(ByteSizeOf(after_all->shape()) == 0);
// No code to generate, but we need to emit an address for book-keeping.
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(gen_token));
+ TF_RETURN_IF_ERROR(EmitTargetAddressForOp(after_all));
+ return Status::OK();
+}
+
+Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
+ // AddDedendency just forwards its zero-th operand.
+ emitted_value_[add_dependency] =
+ GetEmittedValueFor(add_dependency->operand(0));
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index f529c61..559a816 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -159,7 +159,8 @@
Status HandleConcatenate(HloInstruction* concatenate) override;
Status HandleConditional(HloInstruction* conditional) override;
Status HandleScatter(HloInstruction* scatter) override;
- Status HandleAfterAll(HloInstruction* gen_token) override;
+ Status HandleAfterAll(HloInstruction* after_all) override;
+ Status HandleAddDependency(HloInstruction* add_dependency) override;
Status HandleRng(HloInstruction* rng) override;
Status FinishVisit(HloInstruction* root) override;
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
index c7fc101..722aa31 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc
@@ -51,19 +51,21 @@
// then y is ordered as an int32 such that finite values have the
// obvious order, -0 is ordered before 0, and -NaN and NaN appear at
// the beginning and end of the ordering.
-template <typename CastType, typename KeyType>
+template <typename CastType, typename UnsignedCastType, typename KeyType>
CastType Convert(KeyType value) {
CastType casted_value;
memcpy(&casted_value, &value, sizeof(CastType));
if (casted_value < 0) {
- return std::numeric_limits<CastType>::max() - casted_value;
+ return static_cast<UnsignedCastType>(std::numeric_limits<CastType>::max()) -
+ casted_value;
}
return casted_value;
}
-template <typename CastType, typename KeyType>
+template <typename CastType, typename UnsignedCastType, typename KeyType>
bool LessThan(KeyType lhs, KeyType rhs) {
- return Convert<CastType>(lhs) < Convert<CastType>(rhs);
+ return Convert<CastType, UnsignedCastType>(lhs) <
+ Convert<CastType, UnsignedCastType>(rhs);
}
template <>
@@ -71,7 +73,7 @@
std::stable_sort(row_to_sort, row_to_sort + num_elements,
[](const std::pair<double, int64>& lhs,
const std::pair<double, int64>& rhs) -> bool {
- return LessThan<int64>(lhs.first, rhs.first);
+ return LessThan<int64, uint64>(lhs.first, rhs.first);
});
}
@@ -80,7 +82,7 @@
std::stable_sort(row_to_sort, row_to_sort + num_elements,
[](const std::pair<float, int64>& lhs,
const std::pair<float, int64>& rhs) -> bool {
- return LessThan<int32>(lhs.first, rhs.first);
+ return LessThan<int32, uint32>(lhs.first, rhs.first);
});
}
@@ -90,7 +92,7 @@
std::stable_sort(row_to_sort, row_to_sort + num_elements,
[](const std::pair<Eigen::half, int64>& lhs,
const std::pair<Eigen::half, int64>& rhs) -> bool {
- return LessThan<int32>(
+ return LessThan<int32, uint32>(
Eigen::half_impl::half_to_float(lhs.first),
Eigen::half_impl::half_to_float(rhs.first));
});
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index f77641e..efccade 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -128,8 +128,18 @@
}
llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
- void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name);
+ void* func_addr = nullptr;
+ if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) {
+ // On Mac OS X, 'name' may have a leading underscore prefix, even though the
+ // registered name may not.
+ std::string stripped_name(name.begin() + 1, name.end());
+ func_addr = CustomCallTargetRegistry::Global()->Lookup(stripped_name);
+ } else {
+ func_addr = CustomCallTargetRegistry::Global()->Lookup(name);
+ }
+
if (func_addr == nullptr) {
+ VLOG(2) << "Unable to resolve runtime symbol: " << name;
return nullptr;
}
llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast<uint64_t>(func_addr),
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index d637128..e84bf00 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -251,6 +251,7 @@
virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0;
+ virtual Status HandleAddDependency(HloInstructionPtr add_dependency) = 0;
virtual Status HandleAfterAll(HloInstructionPtr token) = 0;
// Invoked to inform the visitor that the traversal has completed, and that
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index e57184f..80ea5be 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -206,6 +206,9 @@
Status HandleGetDimensionSize(HloInstructionPtr get_size) override {
return DefaultAction(get_size);
}
+ Status HandleAddDependency(HloInstructionPtr add_dependency) override {
+ return DefaultAction(add_dependency);
+ }
// Invoked to inform the visitor that the traversal has completed, and that
// the root was "root".
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
index 30c1f90..4704579 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc
@@ -229,7 +229,7 @@
if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) {
return user->opcode() == HloOpcode::kFusion &&
(user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
- (user->fusion_kind() == HloInstruction::FusionKind::kInput &&
+ (IsReduceInputFusion(*user) &&
LayoutsAreReduceInputFusionFriendly(*fusion, *user)));
})) {
VLOG(3) << "Not merging " << fusion->name()
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
index 2d31fd5..392b149 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
@@ -55,7 +55,7 @@
});
}
-bool IsInputFusibleReduction(const HloInstruction& instr) {
+bool IsReduceInputFusion(const HloInstruction& instr) {
if (instr.IsMultiOutputFusion()) {
for (const HloInstruction* operand :
instr.fused_expression_root()->operands()) {
@@ -67,17 +67,18 @@
return true;
}
}
- return false;
- } else if (instr.opcode() == HloOpcode::kFusion) {
- if (IsReductionToVector(*instr.fused_expression_root())) {
- CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput)
- << " Fusion rooted at reduction-to-vector op must be of kind kInput: "
- << instr.ToString();
- return true;
- }
- return false;
+ } else if (instr.opcode() == HloOpcode::kFusion &&
+ IsReductionToVector(*instr.fused_expression_root())) {
+ CHECK(instr.fusion_kind() == HloInstruction::FusionKind::kInput)
+ << " Fusion rooted at reduction-to-vector op must be of kind kInput: "
+ << instr.ToString();
+ return true;
}
- return IsReductionToVector(instr);
+ return false;
+}
+
+bool IsInputFusibleReduction(const HloInstruction& instr) {
+ return IsReduceInputFusion(instr) || IsReductionToVector(instr);
}
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
index f7c24a0..c0be354 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
@@ -33,14 +33,17 @@
bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
const HloInstruction& reduce);
-// Whether `instr` is fusible as root of a reduce input fusions, i.e. `instr`
-// is either an unfused reduction-to-vector op, an input fusion rooted at a
-// reduction-to-vector op, or a multi-output input fusion with at least one
-// reduction-to-vector op root.
// Note that reduction ops are lowered in different ways. Reduce input fusions
// are lowered by IrEmitterUnnested::EmitReductionToVector and must be rooted at
// reduction-to-vector ops. Other reduction ops are lowered by
// GpuElementalIrEmitter and fused like elementwise ops.
+
+// Whether `instr` is an input fusion rooted at a reduction-to-vector op or a
+// multi-output input fusion with at least one reduction-to-vector op root.
+bool IsReduceInputFusion(const HloInstruction& instr);
+
+// Whether `instr` is fusible as root of a reduce input fusions, i.e. `instr`
+// is either an unfused reduction-to-vector op or a reduce input fusion.
bool IsInputFusibleReduction(const HloInstruction& instr);
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
index d91b7bc..1222250 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible_test.cc
@@ -178,7 +178,7 @@
EXPECT_TRUE(LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce));
}
-TEST_F(GpuFusibleTest, IsInputFusibleReduction_ReductionToVector) {
+TEST_F(GpuFusibleTest, IsReduceInputFusion_ReductionToVector) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
ENTRY entry {
c0 = f32[] parameter(0)
@@ -191,10 +191,11 @@
const HloInstruction* reduce =
module->entry_computation()->root_instruction();
ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce);
+ EXPECT_FALSE(IsReduceInputFusion(*reduce));
EXPECT_TRUE(IsInputFusibleReduction(*reduce));
}
-TEST_F(GpuFusibleTest, IsInputFusibleReduction_ElementalReduction) {
+TEST_F(GpuFusibleTest, IsReduceInputFusion_ElementalReduction) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
ENTRY entry {
c0 = f32[] parameter(0)
@@ -207,10 +208,11 @@
const HloInstruction* reduce =
module->entry_computation()->root_instruction();
ASSERT_EQ(reduce->opcode(), HloOpcode::kReduce);
+ EXPECT_FALSE(IsReduceInputFusion(*reduce));
EXPECT_FALSE(IsInputFusibleReduction(*reduce));
}
-TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputInputReduceFusion) {
+TEST_F(GpuFusibleTest, IsReduceInputFusion_SingleOutputInputReduceFusion) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_reduction {
c0 = f32[] parameter(0)
@@ -225,10 +227,11 @@
const HloInstruction* reduce =
module->entry_computation()->root_instruction();
ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_TRUE(IsReduceInputFusion(*reduce));
EXPECT_TRUE(IsInputFusibleReduction(*reduce));
}
-TEST_F(GpuFusibleTest, IsInputFusibleReduction_SingleOutputLoopReduceFusion) {
+TEST_F(GpuFusibleTest, IsReduceInputFusion_SingleOutputLoopReduceFusion) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_reduction {
c0 = f32[] parameter(0)
@@ -243,10 +246,11 @@
const HloInstruction* reduce =
module->entry_computation()->root_instruction();
ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_FALSE(IsReduceInputFusion(*reduce));
EXPECT_FALSE(IsInputFusibleReduction(*reduce));
}
-TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputInputReduceFusion) {
+TEST_F(GpuFusibleTest, IsReduceInputFusion_MultiOutputInputReduceFusion) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_reduction {
c0 = f32[] parameter(0)
@@ -263,11 +267,12 @@
const HloInstruction* reduce =
module->entry_computation()->root_instruction();
ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_TRUE(IsReduceInputFusion(*reduce));
EXPECT_TRUE(IsInputFusibleReduction(*reduce));
}
TEST_F(GpuFusibleTest,
- IsInputFusibleReduction_MultiOutputInputReduceFusionWithExtraOutputs) {
+ IsReduceInputFusion_MultiOutputInputReduceFusionWithExtraOutputs) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_reduction {
c0 = f32[] parameter(0)
@@ -284,10 +289,11 @@
const HloInstruction* reduce =
module->entry_computation()->root_instruction();
ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_TRUE(IsReduceInputFusion(*reduce));
EXPECT_TRUE(IsInputFusibleReduction(*reduce));
}
-TEST_F(GpuFusibleTest, IsInputFusibleReduction_MultiOutputLoopReduceFusion) {
+TEST_F(GpuFusibleTest, IsReduceInputFusion_MultiOutputLoopReduceFusion) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_reduction {
c0 = f32[] parameter(0)
@@ -304,11 +310,12 @@
const HloInstruction* reduce =
module->entry_computation()->root_instruction();
ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_FALSE(IsReduceInputFusion(*reduce));
EXPECT_FALSE(IsInputFusibleReduction(*reduce));
}
TEST_F(GpuFusibleTest,
- IsInputFusibleReduction_MultiOutputLoopFusionReduceAndElementwiseOp) {
+ IsReduceInputFusion_MultiOutputLoopFusionReduceAndElementwiseOp) {
auto module = ParseHloString(absl::StrCat(kModulePrefix, R"(
fused_reduction {
c0 = f32[] parameter(0)
@@ -325,6 +332,7 @@
const HloInstruction* reduce =
module->entry_computation()->root_instruction();
ASSERT_EQ(reduce->opcode(), HloOpcode::kFusion);
+ EXPECT_FALSE(IsReduceInputFusion(*reduce));
EXPECT_FALSE(IsInputFusibleReduction(*reduce));
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 1c0a23f..f59da2c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -65,8 +65,8 @@
VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();
- // Empirically we've found with Volta and cudnn 7 that backward-input convs
- // with stride are significantly faster with NCHW layouts.
+ // Empirically we've found with Volta and cudnn <= 7.3 that backward-input
+ // convs with stride are significantly faster with NCHW layouts.
//
// We could have used a mixed layout combination, e.g. (NHWC, NCHW, NCHW),
// which on paper gives good performance. However, there are two observations:
@@ -75,11 +75,17 @@
// * we've also observed that for mixed layouts, cuDNN transposes data back
// and forth from a different layout combination. If we end up with
// transposes anyway, we prefer to have them in XLA, as they can be fused.
- // TODO(timshen): Figure out the exact condition. This may be achieved by
- // auto-tuning layouts offline.
- if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
- window_util::HasStride(instr->window())) {
- return kAllNCHW;
+ if (auto* dnn = stream_executor->AsDnn()) {
+ auto version_status = dnn->GetVersion();
+ if (version_status.ok()) {
+ auto version = version_status.ConsumeValueOrDie();
+ if (std::make_tuple(version.major_version(), version.minor_version()) <=
+ std::make_tuple(7, 3) &&
+ instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
+ window_util::HasStride(instr->window())) {
+ return kAllNCHW;
+ }
+ }
}
// For other Volta f16 convolutions, use NHWC.
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 43f43b5..6151dd8 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -80,7 +80,7 @@
// This function limits the maximum number of operands to a fusion.
//
// There's a cap on how many parameters we can pass to a CUDA kernel, but
-// exactly what that limit is is hazy, as it depends on (among other things) how
+// exactly what that limit is hazy, as it depends on (among other things) how
// much GPU constant memory is in use for other purposes.
//
// Moreover, we don't even know at the point that we're running fusion how many
@@ -181,7 +181,8 @@
return true;
}
} else if (consumer->operand_count() == 2 &&
- consumer->opcode() == HloOpcode::kAdd) {
+ consumer->opcode() == HloOpcode::kAdd &&
+ consumer->operand(other_operand_index) != producer) {
// Fuse a bias add into the output of the dot.
return true;
}
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 2b060b0..688604c 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -358,6 +358,29 @@
op::Parameter()));
}
+TEST_F(InstructionFusionTest,
+ DotOperationFusion_DontOutputFuseDuplicateOperands) {
+ absl::string_view module_string = R"(
+HloModule module
+
+ENTRY main {
+ a = f32[50,60]{1,0} parameter(0)
+ b = f32[60,1]{1,0} parameter(1)
+ c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT d = f32[50,1]{1,0} add(c, c)
+}
+)";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(module_string));
+ TF_ASSERT_OK_AND_ASSIGN(
+ bool fused_something,
+ GpuInstructionFusion(/*may_duplicate=*/false).Run(module.get()));
+ EXPECT_FALSE(fused_something);
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ Not(op::Fusion()));
+}
+
// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
// duplicated and fused into both reduces.
TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 7fcdd80..3159191 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -97,6 +97,18 @@
return Status::OK();
}
+Status IrEmitter::HandleAddDependency(HloInstruction* add_dependency) {
+ VLOG(2) << "HandleAddDependency: " << add_dependency->ToString();
+ const HloInstruction* operand = add_dependency->operand(0);
+ // Add_Dependency is a no-op, but we still want to bind it to an llvm::Value
+ // sometimes, e.g., when it's operand is a constant or a bitcast of a
+ // constant.
+ if (bindings_.BoundToIrValue(*operand)) {
+ bindings_.BindHloToIrValue(*add_dependency, GetBasePointer(*operand));
+ }
+ return Status::OK();
+}
+
Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
auto operand = get_tuple_element->operand(0);
CHECK(bindings_.BoundToIrValue(*operand));
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 56c3f45..2da46c0 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -100,6 +100,7 @@
Status HandleBatchNormInference(HloInstruction* batch_norm) override;
Status HandleBatchNormTraining(HloInstruction* batch_norm) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm) override;
+ Status HandleAddDependency(HloInstruction* add_dependency) override;
Status FinishVisit(HloInstruction* root) override { return Status::OK(); }
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index efe335c..bbe1583 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -22,7 +22,6 @@
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
#include "absl/algorithm/container.h"
-#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
@@ -65,6 +64,7 @@
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
@@ -547,91 +547,7 @@
// TODO(b/112040122): Support variadic reduce.
return Unimplemented("Variadic reduce is not supported on GPU");
}
- VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString();
- std::vector<std::unique_ptr<Thunk>> thunks;
- absl::Span<HloInstruction* const> output_instructions =
- root->opcode() == HloOpcode::kTuple
- ? root->operands()
- : absl::Span<HloInstruction* const>(&root, 1);
-
- // For multi-output fusion emit an initializer for each tuple element.
- // Otherwise it's sufficient to just initialize the single output.
- HloInstruction* first_reduce = nullptr;
- for (int i = 0, e = output_instructions.size(); i != e; ++i) {
- if (output_instructions[i]->opcode() == HloOpcode::kReduce) {
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<Thunk> initializer_thunk,
- BuildInitializerThunk(fusion, output_instructions[i] == root
- ? ShapeIndex()
- : ShapeIndex({i})));
- thunks.push_back(std::move(initializer_thunk));
- first_reduce =
- first_reduce == nullptr ? output_instructions[i] : first_reduce;
- }
- }
- CHECK(first_reduce != nullptr);
- std::unique_ptr<KernelThunk> kernel_thunk =
- BuildKernelThunk(fusion, /*implements_whole_instruction=*/false);
- GpuElementalIrEmitter elemental_emitter(
- hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
- GetNestedComputer());
- FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(fusion),
- &elemental_emitter);
- TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
-
- // For multi-output fusion CHECK the constraints and feed all the
- // reduces into a single loop code generator. Single-output reduce
- // fusion is a special case of that.
- InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
- InlinedVector<llvm_ir::ElementGenerator, 1> init_value_gens;
- std::vector<std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
- extra_output_gens;
- InlinedVector<HloComputation*, 1> reducers;
- InlinedVector<ShapeIndex, 1> reduce_output_shapes;
- for (int i = 0, e = output_instructions.size(); i != e; ++i) {
- const HloInstruction* inst = output_instructions[i];
- ShapeIndex output_shape_index;
- if (root->opcode() == HloOpcode::kTuple) {
- output_shape_index = {i};
- }
- if (inst->opcode() == HloOpcode::kReduce) {
- CHECK(IsReductionToVector(*inst))
- << "Only reductions to vector are supported";
- // Shapes, layouts and dimensions must be the same for all reduces
- // inside of this fusion.
- CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape()));
- CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(),
- inst->operand(0)->shape()));
- CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(),
- inst->operand(1)->shape()));
- CHECK(first_reduce->dimensions() == inst->dimensions());
- input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0)));
- init_value_gens.push_back(
- fused_emitter.GetGenerator(inst->operand(1)));
- reducers.push_back(inst->to_apply());
- reduce_output_shapes.push_back(std::move(output_shape_index));
- } else {
- // For extra outputs we can relax shape equality to allow different
- // types (with the same number of elements). Layouts still have to
- // match.
- CHECK(ShapeUtil::CompatibleIgnoringElementType(
- first_reduce->operand(0)->shape(), inst->shape()));
- CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
- inst->shape().layout()));
- extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst),
- std::move(output_shape_index));
- }
- }
- const Shape& input_shape = first_reduce->operand(0)->shape();
- TF_CHECK_OK(EmitReductionToVector(
- kernel_thunk.get(), first_reduce, input_shape, input_gens,
- init_value_gens, first_reduce->dimensions(), reducers,
- reduce_output_shapes, extra_output_gens));
- thunks.push_back(std::move(kernel_thunk));
- std::unique_ptr<SequentialThunk> sequential_thunk =
- absl::make_unique<SequentialThunk>(std::move(thunks), fusion);
- AddThunkToThunkSequence(std::move(sequential_thunk));
- return Status::OK();
+ return EmitReductionToVector(fusion);
}
default:
LOG(FATAL) << "Bad opcode for input fusion: "
@@ -701,13 +617,12 @@
}
Status IrEmitterUnnested::EmitExtraOutputsForReduce(
- const HloInstruction* reduce, const IrArray::Index& index,
+ const HloInstruction* unnested_hlo, const IrArray::Index& index,
absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
for (int i = 0; i != extra_output_gens.size(); ++i) {
- const HloInstruction* output = reduce->parent()->FusionInstruction();
llvm::Value* extra_output_address =
- GetIrArray(*output, *output, extra_output_gens[i].second)
+ GetIrArray(*unnested_hlo, *unnested_hlo, extra_output_gens[i].second)
.EmitArrayElementAddress(index, &b_,
"extra_output_element_address");
TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
@@ -717,984 +632,13 @@
return Status::OK();
}
-Status IrEmitterUnnested::EmitReductionToScalar(
- KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape,
- absl::Span<const llvm_ir::ElementGenerator> input_gens,
- absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
- absl::Span<HloComputation* const> reducers,
- absl::Span<const ShapeIndex> reduce_output_shapes,
- absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
- extra_output_gens) {
- // Number of elements processed by a single thread.
- constexpr int64 kTileSize = 16;
- int64 num_elems = ShapeUtil::ElementsIn(input_shape);
-
- // Round up the number of tiles to a multiple of the warp size. This is
- // necessary for correctness. We launch one thread per tile, and if the
- // number of threads isn't a multiple of the number of the warp size, our
- // shuffles will read from inactive threads, producing undefined values.
- int64 num_tiles =
- RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize);
-
- Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
- reduce->shape().element_type(), {num_tiles}, {0});
- LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
- tiled_input_shape, ir_emitter_context_->device_description());
-
- llvm::Type* index_ty =
- GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_);
-
- auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
- return llvm::ConstantInt::get(index_ty, c);
- };
-
- // Check whether every thread will process a full tile's worth of elements
- // without reading outside the bounds of the input. If this is true, we can
- // skip some bounds checks in the final algorithm.
- bool all_threads_in_bounds = num_tiles * kTileSize == num_elems;
-
- // __global__ void full_reduce_kernel() {
- // x_in_tiles = threadIdx.x + blockIdx.x * blockDim.x;
- // x = x_in_tiles * kTileSize;
- //
- // partial_result = init_value;
- // if (all_threads_in_bounds || x + kTileSize <= num_elems) {
- // for (i = 0; i < kTileSize; ++i) {
- // partial_result = Reducer(partial_result, input[x + i]);
- // }
- // } else {
- // for (i = 0; i < kTileSize; ++i) {
- // if (x + i < num_elems) {
- // partial_result = Reducer(partial_result, input[x + i]);
- // }
- // }
- // }
- // for (i = warpSize / 2; i > 0; i /= 2) {
- // partial_result = Reducer(partial_result,
- // __shfl_down(partial_result, i));
- // }
- // if (lane_id == 0) {
- // AtomicReducer(&output[y], partial_result);
- // }
- // }
- //
- // // Choose num_blocks and threads_per_block such that:
- // //
- // // num_blocks * threads_per_block =
- // // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize),
- // //
- // // and threads_per_block is a multiple of warpSize.
- // reduce_kernel //
- auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
- const int num_reduces = reducers.size();
- llvm::Type* element_ir_type =
- llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
- std::vector<llvm::Value*> partial_reduction_result_addresses;
- for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result_address =
- Alloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
- TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
- init_value_gens[i](IrArray::Index(index_ty)));
- Store(init_ir_value, partial_reduction_result_address);
- partial_reduction_result_addresses.push_back(
- partial_reduction_result_address);
- }
-
- llvm::Value* x_in_tiles = tile_index[0];
- x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty);
-
- // Emit an inner for-loop that reduces the elements in the tile.
- auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
- std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
- llvm_ir::ForLoop::EmitForLoop(
- "element_id_in_tile", index_typed_constant(0),
- index_typed_constant(kTileSize), index_typed_constant(1), &b_);
-
- // Emit the body of the partial reduction loop.
- llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
- &b_);
- llvm::Value* x =
- NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)),
- tile_element_loop->GetIndVarValue());
- // Unless we know the tile is entirely in bounds, we have to emit a
- // x-in-bounds check before reading from the input.
- if (!tile_in_bounds) {
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_);
-
- // Emit code that reads the input element and accumulates it to
- // the partial reduction result.
- llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
- }
-
- IrArray::Index input_index(
- /*linear=*/x, input_shape, &b_);
- llvm::Value* input_address = Alloca(element_ir_type);
- for (int i = 0; i != num_reduces; ++i) {
- TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
- input_gens[i](input_index));
- Store(input_ir_value, input_address);
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *reducers[i],
- {partial_reduction_result_addresses[i], input_address},
- partial_reduction_result_addresses[i]));
- }
- return EmitExtraOutputsForReduce(reduce, input_index, extra_output_gens);
- };
-
- // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's
- // immediately beyond the tile.
- llvm::Value* x_end =
- NSWAdd(index_typed_constant(kTileSize),
- NSWMul(x_in_tiles, index_typed_constant(kTileSize)));
- // The tile is entirely in bound if all_threads_in_bounds or
- // x_end <= num_elems.
- llvm::Value* tile_in_bounds =
- Or(ICmpULE(x_end, index_typed_constant(num_elems)),
- b_.getInt1(all_threads_in_bounds));
- llvm_ir::LlvmIfData if_tile_in_bounds_data =
- llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_);
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_);
- TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, &b_);
- TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
-
- // After the if-then-else statement on tile_in_bounds, emit calls to
- // shfl_down that accumulate the partial reduction results of all threads
- // from the warp.
- llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &b_);
- int bit_width = llvm_ir::GetSizeInBits(element_ir_type);
- // bitcast cannot be applied to aggregate types (even packed ones), so we
- // instead bitcast addresses of load/store to intN* of the same bit-width.
- llvm::Type* shuffle_ir_type = element_ir_type->isStructTy()
- ? b_.getIntNTy(bit_width)
- : element_ir_type;
- for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1;
- shuffle_distance /= 2) {
- llvm::Value* result_from_other_lane =
- Alloca(element_ir_type, nullptr, "result_from_other_lane");
- for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result =
- Load(BitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
- "partial_reduction_result");
- CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
- << "Requires block size a multiple of the warp size, otherwise we "
- "will read undefined elements.";
- Store(EmitFullWarpShuffleDown(partial_reduction_result,
- b_.getInt32(shuffle_distance), &b_),
- BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo()));
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *reducers[i],
- {partial_reduction_result_addresses[i], result_from_other_lane},
- partial_reduction_result_addresses[i]));
- }
- }
-
- const HloInstruction* output =
- reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
-
- // Emit an atomic operation that accumulates the partial reduction result of
- // lane 0 (which holds the partially accumulated result for its warp) to the
- // output element.
- llvm::Value* lane_id =
- URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id");
- llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_);
- llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
-
- for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* output_address =
- GetIrArray(*output, *output, reduce_output_shapes[i])
- .EmitArrayElementAddress(
- IrArray::Index(
- /*linear=*/b_.getInt64(0),
- ShapeUtil::GetSubshape(output->shape(),
- reduce_output_shapes[i]),
- &b_),
- &b_, "output_element_address");
- TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
- *reducers[i], output_address, partial_reduction_result_addresses[i]));
- }
- return Status::OK();
- };
-
- // Emit a parallel loop that iterates through all input tiles, one per thread.
- UpdateLaunchDimensions(launch_dimensions, kernel_thunk,
- ir_emitter_context_->llvm_module());
- return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
- launch_dimensions, &b_)
- .EmitLoop(IrName(reduce), index_ty);
-}
-
-Status IrEmitterUnnested::EmitColumnReduction(
- KernelThunk* kernel_thunk, int64 height, int64 width,
- HloInstruction* reduce, const Shape& input_shape,
- absl::Span<const llvm_ir::ElementGenerator> input_gens,
- absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
- absl::Span<HloComputation* const> reducers,
- absl::Span<const ShapeIndex> reduce_output_shapes,
- absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
- extra_output_gens) {
- // Divide the input matrix into tiles of size KxL. For example, when the
- // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like
- //
- // 0123
- // 0123
- // 4567
- // 4567 // Numbers indicate tile IDs.
- //
- // Each tile is first partially reduced to a scalar by a thread, and then the
- // scalar is accumulated to the output vector using atomic operations.
- //
- // We choose 128 as the tile size based on empirical evidence. It's big enough
- // to reduce the amount of atomic adds in the end, maximizing the memory
- // bandwidth. A tile width of 2 allows for high memory bandwidth utilization
- // on 16b input data.
- constexpr int64 kTileHeight = 128;
- constexpr int64 kTileWidth = 2;
-
- // If the height is not a multiple of kTileHeight, we pad the bottom of the
- // input matrix.
- const int64 height_in_tiles = CeilOfRatio(height, kTileHeight);
- // If width is not a multiple of kTileWidth the rightmost thread will process
- // fewer input elements.
- const int64 width_in_tiles = CeilOfRatio(width, kTileWidth);
- Shape tiled_input_shape =
- ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(),
- {height_in_tiles, width_in_tiles}, {1, 0});
- LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
- tiled_input_shape, ir_emitter_context_->device_description());
-
- // TODO(b/110211620): Convert to use i32 index_type when it is possible.
- llvm::Type* index_ty = b_.getInt64Ty();
-
- auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
- return llvm::ConstantInt::get(index_ty, c);
- };
-
- // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
- // linear_index < height_in_tiles * width_in_tiles;
- // linear_index += blockDim.x * gridDim.x) {
- // y_in_tiles = linear_index / width_in_tiles;
- // x_in_tiles = linear_index % width_in_tiles;
- //
- // partial_results[kTileWidth] = init_values;
- // tile_in_y_bounds = height % kTileHeight == 0 ||
- // y_in_tiles * kTileHeight + kTileHeight <= height;
- // tile_in_x_bounds = width % kTileWidth == 0 ||
- // x_in_tiles * kTileWidth + kTileWidth <= width;
- // // The implementation handles y and x bound checks separately.
- // if (tile_in_y_bounds && tile_in_x_bounds) {
- // for (y_offset : range(kTileHeight)) {
- // y = y_in_tiles * kTileHeight + y_offset;
- // for (x_offset : range(kTileWidth)) {
- // x = x_in_tiles * kTileWidth + x_offset;
- // partial_result = Reducer(partial_result[x_offset], input[y][x]);
- // }
- // }
- // } else {
- // for (y_offset : range(kTileHeight)) {
- // y = y_in_tiles * kTileHeight + y_offset;
- // for (y_offset : range(kTileHeight)) {
- // x = x_in_tiles * kTileWidth + x_offset;
- // if (y < height && x < width) {
- // partial_result = Reducer(partial_result, input[y][x]);
- // }
- // }
- // }
- // }
- // for (x_offset : range(kTileWidth)) {
- // AtomicReducer(&output[x + x_offset], partial_result[x_offset]);
- // }
- // }
- auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
- const int num_reduces = reducers.size();
- // Emit the loop body that reduces one tile.
- llvm::Type* element_ir_type =
- llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
- std::vector<llvm::Value*> partial_reduction_result_addresses;
- for (int i = 0; i != num_reduces; ++i) {
- for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
- llvm::Value* partial_reduction_result_address =
- Alloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." +
- llvm::Twine(i * kTileWidth + x_offset));
- TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
- init_value_gens[i](IrArray::Index(index_ty)));
- Store(init_ir_value, partial_reduction_result_address);
- partial_reduction_result_addresses.push_back(
- partial_reduction_result_address);
- }
- }
-
- // Emit an inner for-loop that partially reduces the elements in the given
- // tile.
- llvm::Value* y_in_tiles = tile_index[0];
- llvm::Value* x_in_tiles = tile_index[1];
-
- y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty);
- x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty);
-
- auto emit_tile_element_loop = [=](bool tile_in_y_bounds,
- bool tile_in_x_bounds) -> Status {
- std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
- llvm_ir::ForLoop::EmitForLoop(
- "element_id_in_tile", index_typed_constant(0),
- index_typed_constant(kTileHeight), index_typed_constant(1), &b_);
-
- // Emit the body of the partial reduction loop.
- llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
- &b_);
- llvm::Value* y =
- NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)),
- tile_element_loop->GetIndVarValue());
-
- // Unless we know that y is in bounds, we have to emit a check before
- // reading from the input.
- if (!tile_in_y_bounds) {
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_);
-
- // Emit code that reads the input element and accumulates it to
- // the partial reduction result.
- llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
- }
- for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
- llvm::Value* x =
- NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
- index_typed_constant(x_offset));
- // Unless we know that x is in bounds, we have to emit a check before
- // reading from the input.
- if (!tile_in_x_bounds) {
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_);
- llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_);
- }
- llvm::Value* input_address = Alloca(element_ir_type);
- // {y,x} is an index to input_matrix_shape [height,width]. We need to
- // convert that to an index to input_shape (the shape of the operand of
- // "reduce"). This conversion is composed of a transposition from
- // input_shape to normalized_input_shape and a reshape from
- // normalized_input_shape to input_matrix_shape.
- const Shape normalized_input_shape =
- ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- input_shape);
- auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape);
- const std::vector<int64> transpose_dimension_mapping(
- input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
-
- const Shape input_matrix_shape =
- ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(),
- {height, width});
- const IrArray::Index input_matrix_index({y, x}, input_matrix_shape,
- &b_);
- const IrArray::Index input_index =
- input_matrix_index
- .SourceIndexOfReshape(input_matrix_shape,
- normalized_input_shape, &b_)
- .SourceIndexOfTranspose(normalized_input_shape, input_shape,
- transpose_dimension_mapping, &b_);
- for (int i = 0; i != num_reduces; ++i) {
- TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
- input_gens[i](input_index));
- Store(input_ir_value, input_address);
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *reducers[i],
- {partial_reduction_result_addresses[i * kTileWidth + x_offset],
- input_address},
- partial_reduction_result_addresses[i * kTileWidth + x_offset]));
- TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index,
- extra_output_gens));
- }
- }
- return Status::OK();
- };
-
- // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location
- // that's immediately beyond the tile.
- llvm::Value* y_end =
- NSWAdd(index_typed_constant(kTileHeight),
- NSWMul(y_in_tiles, index_typed_constant(kTileHeight)));
- // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location
- // that's immediately beyond the tile.
- llvm::Value* x_end =
- NSWAdd(index_typed_constant(kTileWidth),
- NSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
- llvm::Value* tile_in_y_bounds =
- Or(ICmpULE(y_end, index_typed_constant(height)),
- b_.getInt1(height % kTileHeight == 0));
- llvm::Value* tile_in_x_bounds =
- Or(ICmpULE(x_end, index_typed_constant(width)),
- b_.getInt1(width % kTileWidth == 0));
- // The tile is in y bounds if "height" is a multiple of kTileHeight or
- // y_end <= height.
- llvm_ir::LlvmIfData if_tile_in_y_bounds_data =
- llvm_ir::EmitIfThenElse(tile_in_y_bounds, "tile_in_y_bounds", &b_);
- llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block, &b_);
- // The tile is in x bounds if "width" is a multiple of kTileWidth or
- // x_end <= width.
- llvm_ir::LlvmIfData if_tile_in_x_bounds_data =
- llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_);
- llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_);
- TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
- /*tile_in_x_bounds=*/true));
- llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_);
- TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
- /*tile_in_x_bounds=*/false));
- llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block, &b_);
- if_tile_in_x_bounds_data =
- llvm_ir::EmitIfThenElse(tile_in_x_bounds, "tile_in_x_bounds", &b_);
- llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block, &b_);
- TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
- /*tile_in_x_bounds=*/true));
- llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block, &b_);
- TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
- /*tile_in_x_bounds=*/false));
-
- // After the nested if-then-else statement on tile_in_y_bounds and
- // tile_in_x_bounds, emit atomic operations to accumulate the partial
- // reduction result to the output element.
- llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block, &b_);
- const HloInstruction* output =
- reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
- for (int i = 0; i != num_reduces; ++i) {
- for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
- llvm::Value* x =
- NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)),
- index_typed_constant(x_offset));
- llvm::Value* output_address =
- GetIrArray(*output, *output, reduce_output_shapes[i])
- .EmitArrayElementAddress(
- IrArray::Index(
- x,
- ShapeUtil::GetSubshape(output->shape(),
- reduce_output_shapes[i]),
- &b_),
- &b_, "output_element_address");
- TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
- *reducers[i], output_address,
- partial_reduction_result_addresses[i * kTileWidth + x_offset]));
- }
- }
- return Status::OK();
- };
-
- // Emit a parallel loop that iterate through all input tiles.
- UpdateLaunchDimensions(launch_dimensions, kernel_thunk,
- ir_emitter_context_->llvm_module());
- return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
- launch_dimensions, &b_)
- .EmitLoop(IrName(reduce), index_ty);
-}
-
-static std::pair<int64, int64> ComputeKernelMappingSchemeForReduction(
- int64 depth, int64 width, int64 kWarpSize) {
- constexpr int64 kTargetNumElementsPerThread = 64;
- int64 x_tile_size = kTargetNumElementsPerThread;
- int64 z_tile_size = 1;
-
- // Only tile along the x dimension with tile size kTargetNumElementsPerThread
- // if doing so doesn't require a slow version of loop with bound check on each
- // dimension. A more sophisticated heuristics is to enable tile along the
- // x dimension with tile size kTargetNumElementsPerThread when either width is
- // a factor of (kWarpSize * kTargetNumElementsPerThread) or width is big
- // enough so that only a small fraction of the threads execute the slow
- // version of loop with bound check.
- if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) {
- x_tile_size = 8;
- z_tile_size = 8;
- while (depth % z_tile_size != 0) {
- z_tile_size -= 1;
- }
- }
-
- return std::pair<int64, int64>(x_tile_size, z_tile_size);
-}
-
-Status IrEmitterUnnested::EmitRowReduction(
- KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width,
- HloInstruction* reduce, const Shape& input_shape,
- absl::Span<const llvm_ir::ElementGenerator> input_gens,
- absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
- absl::Span<HloComputation* const> reducers,
- absl::Span<const ShapeIndex> reduce_output_shapes,
- absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
- extra_output_gens) {
- // A naive algorithm is:
- // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX.
- // 2. Partially reduces each tile to a scalar using one thread.
- // 3. Accumulates that scalar to the output vector using atomic operations.
- //
- // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
- // linear_index < depth * height * width_in_tiles;
- // linear_index += blockDim.x * gridDim.x) {
- // int x_in_tiles = linear_index % width_in_tiles;
- // int y = linear_index / width_in_tiles % height;
- // int z = linear_index / (height * width_in_tiles);
- // float partial_result = 0;
- // for (element_id_in_tile : range(x_tile_size)) {
- // int x = x_in_tiles * x_tile_size + element_id_in_tile;
- // if (x < width)
- // partial_result = reducer(partial_result, input[z][y][x]);
- // }
- // AtomicReducer(&output[y], partial_result);
- // }
- //
- // Four optimizations are performed.
- //
- // 1. To coalesce global memory accesses, dilate the tile with a factor of 32
- // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead
- // of making each tile consecutive, we let make tile 0 column
- // [0,32,64,...,224], tile 1 column [1,33,65,...,225], and so on. This ensures
- // that threads in a warp access consecutive memory in one iteration (i.e.
- // coalesced). In the above example, the warp that contains thread 0-31
- // accesses column 0-31 in the first iteration, and 32-63 in the second
- // iteration, and so on.
- //
- // 2. Partially accumulate partial reduced results computed by threads in the
- // same warp using shfl_down. Using shfl_down is faster than directly using
- // atomic operations because shfl_down transfers the data between threads
- // using shared memory and threads in the same warp run in lock step (thus no
- // extra synchronization needed). See
- // https://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/
- // for details. The downside is, to produce correct results when using
- // shfl_down, we need to guarantee threads in the same warp work on input
- // elements with the same y, so the number of tiles in each row must be a
- // multiple of 32.
- //
- // 3. Specialize the case that the entire tile is in bounds. When that is
- // true, we don't need to emit "if(x<width)" inside the loop on
- // element_id_in_tile, which makes the code more friendly to optimizations
- // such as LICM.
- //
- // 4. When the width is too small and x_tile_size is less than the target
- // number of elements per thread and use a small factor of depth as
- // z_tile_size to increase the number of elements calculated by each
- // partial sum. This can reduce the needed number of dynamic shfl_down and
- // atomic operations.
- //
- // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
- // linear_index < depth * height * width_in_tiles;
- // linear_index += blockDim.x * gridDim.x) {
- // int x_in_tiles = linear_index % width_in_tiles;
- // int y = linear_index / width_in_tiles % height;
- // int z_in_tiles = linear_index / (height * width_in_tiles);
- // int warp_id = x_in_tiles / warpSize;
- // int lane_id = x_in_tiles % warpSize;
- // float partial_result = 0;
- // int x = warp_id * kTileSize * warpSize + lane_id;
- // if (width % (x_tile_size * warpSize) == 0 ||
- // x + (x_tile_size - 1) * warpSize < width) {
- // // The entire x_tile is in bounds.
- // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size;
- // ++element_id_in_z_tile) {
- // z = z_in_tiles * z_tile_size + element_id_in_z_tile;
- // int tx = x;
- // for (int element_id_in_x_tile = 0;
- // element_id_in_x_tile < x_tile_size;
- // ++element_id_in_x_tile, tx += warpSize) {
- // partial_result = Reducer(partial_result, input[z][y][tx]);
- // }
- // }
- // } else {
- // // The tile is partially in bounds.
- // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size;
- // ++element_id_in_z_tile) {
- // z = z_in_tiles * z_tile_size + element_id_in_z_tile;
- // int tx = x;
- // for (int element_id_in_x_tile = 0; element_id_in_x_tile <
- // x_tile_size; ++element_id_in_tile, tx += warpSize) {
- // if (tx < width)
- // partial_result = Reducer(partial_result, input[z][y][tx]);
- // }
- // }
- // }
- // for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2)
- // partial_result = Reducer(
- // partial_result,
- // __shfl_down_sync(CUDA_WARP_ALL, partial_result, shuffle_distance));
- // if (lane_id == 0)
- // AtomicReducer(&output[y], partial_result);
- // }
- //
-
- int64 x_tile_size;
- int64 z_tile_size;
- std::tie(x_tile_size, z_tile_size) =
- ComputeKernelMappingSchemeForReduction(depth, width, kWarpSize);
-
- // Round the width in tiles up to the nearest multiple of kWarpSize, so that
- // the use of shfl_down is valid.
- const int64 width_in_tiles =
- RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize);
- Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
- reduce->shape().element_type(),
- {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0});
- LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
- tiled_input_shape, ir_emitter_context_->device_description());
- llvm::Type* index_ty =
- GetIndexTypeForKernel(reduce, launch_dimensions.launch_bound(), &b_);
-
- auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
- return llvm::ConstantInt::get(index_ty, c);
- };
-
- auto loop_body_emitter = [=](const IrArray::Index& tile_index) {
- const int num_reduces = reducers.size();
- llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
- input_shape.element_type(), ir_emitter_context_->llvm_module());
- std::vector<llvm::Value*> partial_reduction_result_addresses;
- for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result_address =
- Alloca(element_ir_type, /*ArraySize=*/nullptr,
- "partial_reduction_result." + llvm::Twine(i));
- TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
- init_value_gens[i](IrArray::Index(index_ty)));
- Store(init_ir_value, partial_reduction_result_address);
- partial_reduction_result_addresses.push_back(
- partial_reduction_result_address);
- }
-
- llvm::Value* z_tile = tile_index[0];
- llvm::Value* y = tile_index[1];
- llvm::Value* x_tile = tile_index[2];
-
- x_tile = ZExtOrTrunc(x_tile, index_ty);
-
- llvm::Value* warp_id =
- UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id");
- llvm::Value* lane_id =
- URem(x_tile, index_typed_constant(kWarpSize), "lane_id");
-
- // The x-location of the last element in this z-x-tile.
- // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size);
- llvm::Value* last_x = NSWAdd(
- lane_id,
- NSWMul(index_typed_constant(kWarpSize),
- NSWAdd(index_typed_constant(x_tile_size - 1),
- NSWMul(warp_id, index_typed_constant(x_tile_size)))));
-
- KernelSupportLibrary ksl(
- &b_,
- /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll,
- /*prevent_vectorization=*/false);
-
- // Emit a for-loop that partially reduces the elements in the given
- // z-x-tile.
- auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds,
- int64 x_tile_loop_bound) -> Status {
- auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status {
- llvm::Value* z =
- NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile));
- TF_RETURN_IF_ERROR(ksl.For(
- "x_tile",
- /*start=*/index_typed_constant(0),
- /*end=*/index_typed_constant(x_tile_loop_bound),
- /*step=*/1, [&](llvm::Value* x_indvar) -> Status {
- // x = lane_id +
- // warpSize * (element_id_in_x_tile + warp_id * x_tile_size);
- llvm::Value* x = NSWAdd(
- lane_id,
- NSWMul(index_typed_constant(kWarpSize),
- NSWAdd(x_indvar,
- NSWMul(warp_id, llvm::ConstantInt::get(
- index_ty, x_tile_size)))));
-
- // Unless we know the x-tile is entirely in bounds, we have to
- // emit a x-in-bounds check before reading from the input.
- if (!x_tile_in_bounds) {
- llvm_ir::LlvmIfData if_x_in_bounds_data =
- llvm_ir::EmitIfThenElse(
- ICmpULT(x, index_typed_constant(width)), "x_in_bounds",
- &b_);
- // Points b_ to the then-block.
- llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
- &b_);
- }
-
- // Emit code that reads the input element and accumulates it
- // to the partial reduction result.
- llvm::Value* input_address = Alloca(element_ir_type);
- {
- // {z,y,x} is an index to input_3d_tensor_shape
- // [depth,height,width]. We need to convert that to an index
- // to input_shape (the shape of the operand of "reduce").
- // This conversion is composed of a transposition from
- // input_shape to normalized_input_shape and a reshape from
- // normalized_input_shape to input_3d_tensor_shape.
- const Shape normalized_input_shape = ShapeUtil::
- MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
- input_shape);
- auto input_shape_min2maj =
- LayoutUtil::MinorToMajor(input_shape);
- const std::vector<int64> transpose_dimension_mapping(
- input_shape_min2maj.rbegin(), input_shape_min2maj.rend());
- const Shape input_3d_tensor_shape =
- ShapeUtil::MakeShapeWithDescendingLayout(
- input_shape.element_type(), {depth, height, width});
- const IrArray::Index input_3d_tensor_index(
- {z, y, x}, input_3d_tensor_shape, &b_);
- const IrArray::Index input_index =
- input_3d_tensor_index
- .SourceIndexOfReshape(input_3d_tensor_shape,
- normalized_input_shape, &b_)
- .SourceIndexOfTranspose(
- normalized_input_shape, input_shape,
- transpose_dimension_mapping, &b_);
-
- for (int i = 0; i != num_reduces; ++i) {
- TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
- input_gens[i](input_index));
- Store(input_ir_value, input_address);
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *reducers[i],
- {partial_reduction_result_addresses[i], input_address},
- partial_reduction_result_addresses[i]));
- }
- return EmitExtraOutputsForReduce(reduce, input_index,
- extra_output_gens);
- }
- }));
- return Status::OK();
- };
-
- return ksl.For("z_tile",
- /*start=*/index_typed_constant(0),
- /*end=*/index_typed_constant(z_tile_size),
- /*step=*/1, emit_z_tile_element_loop);
- };
-
- llvm::Value* tile_in_bounds =
- Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0),
- ICmpULT(last_x, index_typed_constant(width)));
-
- TF_RETURN_IF_ERROR(
- ksl.If(tile_in_bounds,
- /*true_block_generator=*/
- [&]() -> Status {
- return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true,
- x_tile_size);
- },
- /*false_block_generator=*/
- [&]() -> Status {
- return emit_z_x_tile_element_loop(
- /*x_tile_in_bounds=*/false,
- CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize));
- }));
-
- // After accumulating the elements of the z_x_tile, emit calls to
- // shfl_down that accumulate the partial reduction results of all
- // threads in a warp.
- int bit_width = llvm_ir::GetSizeInBits(element_ir_type);
- // bitcast cannot be applied to aggregate types (even packed ones), so we
- // instead bitcast addresses of load/store to intN* of the same bit-width.
- llvm::Type* shuffle_ir_type = element_ir_type->isStructTy()
- ? b_.getIntNTy(bit_width)
- : element_ir_type;
- for (int shuffle_distance = 16; shuffle_distance >= 1;
- shuffle_distance /= 2) {
- llvm::Value* result_from_other_lane =
- Alloca(element_ir_type, nullptr, "result_from_other_lane");
- for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* partial_reduction_result =
- Load(BitCast(partial_reduction_result_addresses[i],
- shuffle_ir_type->getPointerTo()),
- "partial_reduction_result");
- CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0)
- << "Requires block size a multiple of the warp size, otherwise we "
- "will read undefined elements.";
- Store(EmitFullWarpShuffleDown(partial_reduction_result,
- b_.getInt32(shuffle_distance), &b_),
- BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo()));
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *reducers[i],
- {partial_reduction_result_addresses[i], result_from_other_lane},
- partial_reduction_result_addresses[i]));
- }
- }
-
- const HloInstruction* output =
- reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
-
- // Emit an atomic operation that accumulates the partial reduction result of
- // lane 0 (which holds the partially accumulated result for its warp) to the
- // output element.
- llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
- ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_);
- llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
- for (int i = 0; i != num_reduces; ++i) {
- llvm::Value* output_address =
- GetIrArray(*output, *output, reduce_output_shapes[i])
- .EmitArrayElementAddress(
- IrArray::Index(y,
- ShapeUtil::GetSubshape(
- output->shape(), reduce_output_shapes[i]),
- &b_),
- &b_, "output_element_address");
- // We don't need to emit atomic operations if there is only one tile of
- // results. 'depth' is the z dimension, 'width' is the x dimension.
- if (z_tile_size >= depth && x_tile_size >= width) {
- TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
- *reducers[i],
- {output_address, partial_reduction_result_addresses[i]},
- output_address));
- } else {
- TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
- *reducers[i], output_address,
- partial_reduction_result_addresses[i]));
- }
- }
- return Status::OK();
- };
-
- // Emit a parallel loop that iterates through every input tiles.
- UpdateLaunchDimensions(launch_dimensions, kernel_thunk,
- ir_emitter_context_->llvm_module());
- return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
- launch_dimensions, &b_)
- .EmitLoop(IrName(reduce), index_ty);
-}
-
-// Figures out whether `reduce` is a row or column reduction, and which
-// dimensions to reduce, and calls either `EmitRowReduction` or
-// `EmitColumnReduction` as appropriate.
-// Prerequisite: all the dimensions to keep are contiguous in the input layout
-// and, if `reduce` is fused, the fused subgraph is pure
-// elementwise.
-Status IrEmitterUnnested::EmitReductionToVector(
- KernelThunk* kernel_thunk, HloInstruction* reduce, const Shape& input_shape,
- absl::Span<const llvm_ir::ElementGenerator> input_gens,
- absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
- absl::Span<const int64> dimensions_to_reduce,
- absl::Span<HloComputation* const> reducers,
- absl::Span<const ShapeIndex> reduce_output_shapes,
- absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
- extra_output_gens) {
- // This emission requires "reduce" to have an input layout. It is either set
- // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for
- // a fused kReduce).
- CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
- "doesn't set the input layout of "
- << reduce->ToString();
-
- // Specialize multi-dimensional-array-to-vector reduction.
- std::vector<int64> input_dims_to_keep;
- for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
- ++input_dim) {
- if (std::find(dimensions_to_reduce.begin(), dimensions_to_reduce.end(),
- input_dim) == dimensions_to_reduce.end()) {
- input_dims_to_keep.push_back(input_dim);
- }
- }
-
- // Sort the dimensions to keep from minor to major, to facilitate checking
- // whether another dimension is major or minor of them.
- std::sort(input_dims_to_keep.begin(), input_dims_to_keep.end(),
- [&input_shape](int64 dim_a, int64 dim_b) {
- return PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
- dim_a) <
- PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
- dim_b);
- });
- // Now, if output rank is at least 1, `input_dims_to_keep.front()` is
- // minormost and `input_dims_to_keep.back()` is majormost.
-
- // If the dimensions to keep are minormost, emit a column reduction. As all
- // the dimensions to keep are contiguous, by prerequisite of
- // `EmitReductionToVector`, we only need to check whether the minormost
- // dimension of the input is to keep.
- if (ShapeUtil::IsEffectiveScalar(reduce->shape())) {
- return EmitReductionToScalar(kernel_thunk, reduce, input_shape, input_gens,
- init_value_gens, reducers,
- reduce_output_shapes, extra_output_gens);
- } else if (input_dims_to_keep.front() ==
- LayoutUtil::Minor(input_shape.layout(), 0)) {
- // Column reduction. Treat the result of "input" as a matrix whose width
- // is the most minor dimension and height the product of other dimensions,
- // and treat "reduce" as a column reduction of the input matrix.
- const int64 width = ShapeUtil::ElementsIn(reduce->shape());
- // "width" can be zero, so don't do
- // height = ShapeUtil::ElementsIn(input_shape) / width;
- int64 height = 1;
- for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
- ++input_dim) {
- if (!std::count(input_dims_to_keep.begin(), input_dims_to_keep.end(),
- input_dim)) {
- height *= input_shape.dimensions(input_dim);
- }
- }
- return EmitColumnReduction(kernel_thunk, height, width, reduce, input_shape,
- input_gens, init_value_gens, reducers,
- reduce_output_shapes, extra_output_gens);
- } else {
- // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a
- // 3D tensor. The size of dimension 1 (the height) is the size of the
- // dimension to keep, the size of dimension 0 (the depth) is the product
- // of dimensions that are more major than the dimension to keep, and the
- // size of dimension 2 (the width) is the product of more minor
- // dimensions.
- int64 depth = 1;
- int64 width = 1;
- for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
- ++input_dim) {
- if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
- input_dim) >
- PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
- input_dims_to_keep.back())) {
- depth *= input_shape.dimensions(input_dim);
- } else if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
- input_dim) <
- PositionInContainer(LayoutUtil::MinorToMajor(input_shape),
- input_dims_to_keep.front())) {
- width *= input_shape.dimensions(input_dim);
- }
- }
- const int64 height = ShapeUtil::ElementsIn(reduce->shape());
- return EmitRowReduction(kernel_thunk, depth, height, width, reduce,
- input_shape, input_gens, init_value_gens, reducers,
- reduce_output_shapes, extra_output_gens);
- }
-}
-
Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
// TODO(b/112040122): Support multi-output reduce.
if (!ShapeUtil::IsArray(reduce->shape())) {
return Unimplemented("Multi-output reduce is not supported on GPU");
}
- auto input = reduce->operand(0);
- auto init_value = reduce->operand(1);
- absl::Span<const int64> dimensions_to_reduce(reduce->dimensions());
- HloComputation* reducer = reduce->to_apply();
- // HandleReduce specializes reduction from a multi-dimensional array to a 1D
- // array. The specialized version requires an initializer thunk that
- // initializes the output array to the initial value of the reduce.
if (IsReductionToVector(*reduce)) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
- BuildInitializerThunk(reduce));
- std::vector<std::unique_ptr<Thunk>> thunks;
- thunks.push_back(std::move(initializer_thunk));
- std::unique_ptr<KernelThunk> kernel_thunk =
- BuildKernelThunk(reduce, /*implements_whole_instruction=*/false);
-
- TF_CHECK_OK(EmitReductionToVector(
- kernel_thunk.get(), reduce, input->shape(),
- {[&](const IrArray::Index& index) {
- return GetIrArray(*input, *reduce).EmitReadArrayElement(index, &b_);
- }},
- {[&](const IrArray::Index& index) {
- return GetIrArray(*init_value, *reduce)
- .EmitReadArrayElement(index, &b_);
- }},
- dimensions_to_reduce, {reducer}, {{}}, {}));
-
- thunks.push_back(std::move(kernel_thunk));
-
- std::unique_ptr<SequentialThunk> sequential_thunk =
- absl::make_unique<SequentialThunk>(std::move(thunks), reduce);
- AddThunkToThunkSequence(std::move(sequential_thunk));
- return Status::OK();
+ return EmitReductionToVector(reduce);
}
return IrEmitter::HandleReduce(reduce);
@@ -1819,7 +763,7 @@
// Create the inner loop to iterate over the window.
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_,
index_type);
- std::vector<int64> window_size;
+ DimensionVector window_size;
for (const auto& dim : window.dimensions()) {
window_size.push_back(dim.size());
CHECK_GT(dim.size(), 0);
@@ -2172,7 +1116,18 @@
Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
std::vector<std::unique_ptr<Thunk>> thunks;
Shape keys_shape = sort->operand(0)->shape();
+ int64 dimension_to_sort = sort->dimensions(0);
+ // In case there is a 'values' parameter that is a iota, we take note and use
+ // it later to ensure a stable sort. Otherwise, we don't guarantee a stable
+ // sort.
+ int64 iota_values_parameter_index = -1;
for (int64 i = 0; i < sort->operand_count(); ++i) {
+ if (i > 0 && sort->operand(i)->opcode() == HloOpcode::kIota &&
+ ShapeUtil::ElementIsIntegral(sort->operand(i)->shape()) &&
+ Cast<HloIotaInstruction>(sort->operand(i))->iota_dimension() ==
+ dimension_to_sort) {
+ iota_values_parameter_index = i;
+ }
ShapeIndex shape_index =
sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({});
// We assume that the layout of all involved operands and outputs is the
@@ -2197,7 +1152,6 @@
}
}
- int64 dimension_to_sort = sort->dimensions(0);
uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
@@ -2299,8 +1253,9 @@
}
}
return llvm_ir::EmitSortInPlace(
- dimension_to_sort, keys_array, values_arrays, IrName(sort), xor_masks,
- &b_, launch_dimensions,
+ dimension_to_sort, keys_array, values_arrays,
+ iota_values_parameter_index, IrName(sort), xor_masks, &b_,
+ launch_dimensions,
xor_masks.size() > 1 ? num_iterations_in_sort_dim
: standard_num_iterations_in_sort_dim,
kTileSize);
@@ -2386,7 +1341,7 @@
return Status::OK();
}
-Status IrEmitterUnnested::HandleAfterAll(HloInstruction* gen_token) {
+Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) {
return Status::OK();
}
@@ -3253,7 +2208,8 @@
builder->CreateAdd(llvm::ConstantInt::get(index_ty, j), x);
ksl->IfReturnVoid(
- "x_in_tile", builder->CreateICmpULT(x_loc, tile_width), [&] {
+ loop_name + "_x_in_tile", builder->CreateICmpULT(x_loc, tile_width),
+ [&] {
// tile_height_bound =
// ceil(tile_height / num_threads_y) * num_threads_y
llvm::Value* ceiling_of_ratio = builder->CreateUDiv(
@@ -3270,8 +2226,8 @@
[&](llvm::Value* y_indvar) {
llvm::Value* y_loc = builder->CreateAdd(y_indvar, y);
ksl->IfReturnVoid(
- "y_in_tile", builder->CreateICmpULT(y_loc, tile_height),
- [&] {
+ loop_name + "_y_in_tile",
+ builder->CreateICmpULT(y_loc, tile_height), [&] {
emit_elem_function(
source_idx.AddOffsetToDim(
y_indvar, KernelMappingScheme::DimY, builder),
@@ -3302,7 +2258,7 @@
llvm::Type* index_ty = tile_width->getType();
ksl->IfReturnVoid(
- "full_tile",
+ loop_name + "_full_tile",
builder->CreateAnd(
builder->CreateICmpEQ(llvm::ConstantInt::get(index_ty, tile_size_x),
tile_width),
@@ -3393,7 +2349,395 @@
}
}
-// Emits a block of tiles, given a function object to emit one tile.
+// Information to support the code generation for a tiled reduction kernel.
+using AddressVector = InlinedVector<llvm::AllocaInst*, 1>;
+class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo {
+ public:
+ explicit ReductionCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme,
+ bool is_row_reduction)
+ : KernelCodegenInfo(mapping_scheme),
+ current_output_linear_index_address_(nullptr),
+ current_output_inbound_address_(nullptr),
+ is_row_reduction_(is_row_reduction) {}
+
+ void SetCurrentOutputLinearIndexAddress(llvm::AllocaInst* a) {
+ current_output_linear_index_address_ = a;
+ }
+ // Returns the address of the memory that stores the linear index of the
+ // current output. Since we are processing reduction to contiguous physical
+ // dimensions, this linear index is the linear index of the 1D output array.
+ llvm::AllocaInst* GetCurrentOutputLinearIndexAddress() const {
+ return current_output_linear_index_address_;
+ }
+
+ void SetCurrentOutputInboundAddress(llvm::AllocaInst* a) {
+ current_output_inbound_address_ = a;
+ }
+
+ llvm::AllocaInst* GetCurrentOutputInboundAddress() const {
+ return current_output_inbound_address_;
+ }
+
+ AddressVector* GetMutablePartialResultAddresses() {
+ return &partial_result_addresses_;
+ }
+ const AddressVector& GetPartialResultAddresses() const {
+ return partial_result_addresses_;
+ }
+
+ AddressVector* GetMutableReductionInputAddresses() {
+ return &reduction_input_addresses_;
+ }
+ const AddressVector& GetReductionInputAddresses() const {
+ return reduction_input_addresses_;
+ }
+
+ InlinedVector<HloComputation*, 1>* GetMutableReducers() { return &reducers_; }
+ const InlinedVector<HloComputation*, 1>& GetReducers() const {
+ return reducers_;
+ }
+ int GetNumberOfReduces() const { return reducers_.size(); }
+
+ InlinedVector<ShapeIndex, 1>* GetMutableReductionOutputShapeIndices() {
+ return &reduction_output_shape_indices_;
+ }
+ const InlinedVector<ShapeIndex, 1>& GetReductionOutputShapeIndices() const {
+ return reduction_output_shape_indices_;
+ }
+
+ bool IsRowReduction() const { return is_row_reduction_; }
+
+ // Return the dimension that is being reduced between DimX and DimY.
+ int GetReducedDimensionEnum() const {
+ return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimX
+ : llvm_ir::KernelMappingScheme::DimY;
+ }
+
+ // Return the dimension that is being ketp between DimX and DimY.
+ int GetKeptDimensionEnum() const {
+ return IsRowReduction() ? llvm_ir::KernelMappingScheme::DimY
+ : llvm_ir::KernelMappingScheme::DimX;
+ }
+
+ private:
+ AddressVector partial_result_addresses_;
+ AddressVector reduction_input_addresses_;
+ InlinedVector<HloComputation*, 1> reducers_;
+ InlinedVector<ShapeIndex, 1> reduction_output_shape_indices_;
+ llvm::AllocaInst* current_output_linear_index_address_;
+ llvm::AllocaInst* current_output_inbound_address_;
+ bool is_row_reduction_;
+};
+
+namespace {
+// Returns a group of instructions that generate the output for the kernel
+// containing the given HLO instruction. The result may be an unnested kReduce
+// HLO, a nested kReduce HLO of a kInput fusion, or the operands of the tuple
+// for a multiple output fusion.
+absl::Span<HloInstruction* const> GetOutputInstructions(
+ HloInstruction* const* reduce_or_tuple_pointer) {
+ HloOpcode opcode = (*reduce_or_tuple_pointer)->opcode();
+ CHECK(opcode == HloOpcode::kReduce || opcode == HloOpcode::kTuple);
+ return opcode == HloOpcode::kTuple
+ ? (*reduce_or_tuple_pointer)->operands()
+ : absl::Span<HloInstruction* const>(reduce_or_tuple_pointer, 1);
+}
+
+const HloInstruction* GetFirstReduceInstruction(
+ absl::Span<HloInstruction* const> instructions) {
+ auto first_reduce_iter =
+ absl::c_find_if(instructions, [](const HloInstruction* inst) {
+ return inst->opcode() == HloOpcode::kReduce;
+ });
+ CHECK_NE(first_reduce_iter, instructions.end());
+ return *first_reduce_iter;
+}
+
+}; // namespace
+
+void IrEmitterUnnested::EmitPrologueForOneReduction(
+ HloInstruction* unnested_hlo, HloInstruction* reduce_inst, int reduce_idx,
+ KernelCodegenInfo* kernel_info, GpuElementalIrEmitter* elemental_emitter,
+ ShapeIndex output_shape_index) {
+ ReductionCodegenInfo* reduction_info =
+ static_cast<ReductionCodegenInfo*>(kernel_info);
+
+ InlinedVector<HloComputation*, 1>* reducers =
+ reduction_info->GetMutableReducers();
+ CHECK(IsReductionToVector(*reduce_inst));
+ reducers->push_back(reduce_inst->to_apply());
+
+ InlinedVector<ShapeIndex, 1>* reduction_output_shape_indices =
+ reduction_info->GetMutableReductionOutputShapeIndices();
+ reduction_output_shape_indices->push_back(std::move(output_shape_index));
+
+ AddressVector* reduction_input_addresses =
+ reduction_info->GetMutableReductionInputAddresses();
+ llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType(
+ reduce_inst->shape().element_type(), ir_emitter_context_->llvm_module());
+ llvm::AllocaInst* reduction_input_address = Alloca(element_type);
+ reduction_input_addresses->push_back(reduction_input_address);
+
+ AddressVector* partial_result_addresses =
+ reduction_info->GetMutablePartialResultAddresses();
+ llvm::AllocaInst* partial_result_address =
+ Alloca(element_type, /*ArraySize=*/nullptr,
+ "partial_reduction_result." + llvm::Twine(reduce_idx));
+ partial_result_addresses->push_back(partial_result_address);
+
+ // Initialize the partial result with the initial value of the reduction.
+ llvm::Value* init_ir_value;
+ if (unnested_hlo->opcode() == HloOpcode::kFusion) {
+ HloInstruction* init_value_operand = reduce_inst->mutable_operand(1);
+ FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
+ elemental_emitter);
+
+ TF_CHECK_OK(init_value_operand->Accept(&fused_emitter));
+ init_ir_value =
+ fused_emitter
+ .GetGenerator(init_value_operand)(IrArray::Index(b_.getInt32Ty()))
+ .ValueOrDie();
+ } else {
+ const HloInstruction* init_value = unnested_hlo->operand(1);
+ init_ir_value =
+ GetIrArray(*init_value, *unnested_hlo)
+ .EmitReadArrayElement(IrArray::Index(b_.getInt32Ty()), &b_);
+ }
+
+ Store(init_ir_value, partial_result_address);
+}
+
+void IrEmitterUnnested::EmitPrologueForReduction(
+ HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) {
+ VLOG(10) << "Emit prologue for reduction " << unnested_hlo->ToString();
+ // Find the unnested kReduce or the tuple that contains a list of kReduce.
+ HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
+ ? unnested_hlo->fused_expression_root()
+ : unnested_hlo;
+ absl::Span<HloInstruction* const> output_instructions =
+ GetOutputInstructions(&reduce_or_tuple);
+ ReductionCodegenInfo* reduction_info =
+ static_cast<ReductionCodegenInfo*>(kernel_info);
+ GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
+ ir_emitter_context_->llvm_module(),
+ &b_, GetNestedComputer());
+ const HloInstruction* first_reduce = nullptr;
+ for (int i = 0, e = output_instructions.size(); i != e; ++i) {
+ if (output_instructions[i]->opcode() != HloOpcode::kReduce) {
+ continue;
+ }
+ HloInstruction* reduce_inst = output_instructions[i];
+ if (first_reduce == nullptr) {
+ first_reduce = reduce_inst;
+ } else {
+ CHECK(first_reduce->dimensions() == reduce_inst->dimensions());
+ }
+ ShapeIndex output_shape_index;
+ if (reduce_or_tuple->opcode() == HloOpcode::kTuple) {
+ output_shape_index = {i};
+ }
+
+ EmitPrologueForOneReduction(unnested_hlo, reduce_inst, i, kernel_info,
+ &elemental_emitter,
+ std::move(output_shape_index));
+ }
+
+ // Allocate stack storage to store the current output linear index and record
+ // the address of the storage.
+ reduction_info->SetCurrentOutputLinearIndexAddress(
+ Alloca(reduction_info->GetIndexType()));
+
+ if (!reduction_info->IsRowReduction()) {
+ llvm::Type* bool_ty = b_.getInt1Ty();
+ llvm::AllocaInst* output_inbound_addr = Alloca(bool_ty);
+ Store(llvm::ConstantInt::get(bool_ty, 0), output_inbound_addr);
+ reduction_info->SetCurrentOutputInboundAddress(output_inbound_addr);
+ }
+}
+
+void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces(
+ const InlinedVector<HloComputation*, 1>& reducers,
+ const AddressVector& partial_result_addresses) {
+ for (int distance = 16; distance >= 1; distance /= 2) {
+ for (int i = 0; i != reducers.size(); ++i) {
+ llvm::Type* element_type =
+ partial_result_addresses[i]->getType()->getElementType();
+ int bit_width = llvm_ir::GetSizeInBits(element_type);
+ llvm::Value* result_from_other_lane = Alloca(
+ element_type, nullptr, "result_from_other_lane" + llvm::Twine(i));
+ // Bitcast cannot be applied to aggregate types (even packed ones), so
+ // we bitcast addresses of load/store to intN* of the same bit-width.
+ llvm::Type* shuffled_value_type =
+ element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type;
+ auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) {
+ return BitCast(ptr, shuffled_value_type->getPointerTo());
+ };
+ llvm::Value* partial_result =
+ Load(convert_pointer_for_shuffle(partial_result_addresses[i]),
+ "partial_reduction_result");
+ Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_),
+ convert_pointer_for_shuffle(result_from_other_lane));
+ TF_CHECK_OK(EmitCallToNestedComputation(
+ *reducers[i], {partial_result_addresses[i], result_from_other_lane},
+ partial_result_addresses[i]));
+ }
+ }
+}
+
+void IrEmitterUnnested::EmitEpilogueForReduction(
+ HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) {
+ ReductionCodegenInfo* reduction_info =
+ static_cast<ReductionCodegenInfo*>(kernel_info);
+ int num_reduces = reduction_info->GetNumberOfReduces();
+ const AddressVector& partial_result_addresses =
+ reduction_info->GetPartialResultAddresses();
+ const InlinedVector<HloComputation*, 1>& reducers =
+ reduction_info->GetReducers();
+ const InlinedVector<ShapeIndex, 1>& reduction_output_shape_indices =
+ reduction_info->GetReductionOutputShapeIndices();
+
+ if (reduction_info->IsRowReduction()) {
+ EmitFullWarpShuffleDownLoopForAllReduces(reducers,
+ partial_result_addresses);
+ llvm::Value* lane_id = reduction_info->GetLaneId();
+ llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
+ ICmpEQ(lane_id, llvm::ConstantInt::get(lane_id->getType(), 0)),
+ "lane_id_is_zero", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
+ } else {
+ llvm::Value* output_inbound_addr =
+ reduction_info->GetCurrentOutputInboundAddress();
+ llvm::Value* output_inbound = Load(output_inbound_addr);
+ llvm_ir::LlvmIfData if_output_inbound_data = llvm_ir::EmitIfThenElse(
+ ICmpEQ(output_inbound,
+ llvm::ConstantInt::get(output_inbound->getType(), 1)),
+ "output_inbound", &b_);
+ llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_);
+ }
+
+ // Emit an atomic operation that accumulates the partial reduction to the
+ // output element. For row reduction, this is only for lane 0 due to the
+ // if-statement emitted above.
+ for (int i = 0; i != num_reduces; ++i) {
+ IrArray::Index element_index(
+ /*linear=*/Load(reduction_info->GetCurrentOutputLinearIndexAddress(),
+ "output_linear_addr"),
+ ShapeUtil::GetSubshape(unnested_hlo->shape(),
+ reduction_output_shape_indices[i]),
+ &b_);
+ llvm::Value* output_address =
+ GetIrArray(*unnested_hlo, *unnested_hlo,
+ reduction_output_shape_indices[i])
+ .EmitArrayElementAddress(element_index, &b_,
+ "output_element_address");
+ // Do not emit atomic operations if each element in the reduction result is
+ // computed by one block, that is the dimension being reduced has only one
+ // block.
+ const llvm_ir::KernelMappingScheme* mapping_scheme =
+ reduction_info->GetKernelMappingScheme();
+ if (mapping_scheme->GetTileBlockSizeForDimension(
+ llvm_ir::KernelMappingScheme::DimZ) == 1 &&
+ mapping_scheme->GetTileBlockSizeForDimension(
+ reduction_info->GetReducedDimensionEnum()) == 1) {
+ TF_CHECK_OK(EmitCallToNestedComputation(
+ *reducers[i], {output_address, partial_result_addresses[i]},
+ output_address));
+ } else {
+ TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
+ *reducers[i], output_address, partial_result_addresses[i]));
+ }
+ }
+}
+
+void IrEmitterUnnested::EmitTileElementForReduction(
+ HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index,
+ const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
+ llvm::Value* x_loc) {
+ VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString();
+ HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
+ ? unnested_hlo->fused_expression_root()
+ : unnested_hlo;
+ llvm_ir::TiledParameterInfo* tiled_param_info =
+ kernel_info->GetTiledParameterInfo();
+ tiled_param_info->set_y(y_loc);
+ tiled_param_info->set_x(x_loc);
+
+ // Record the linear address for the current reduction.
+ const ReductionCodegenInfo* reduction_info =
+ dynamic_cast<const ReductionCodegenInfo*>(kernel_info);
+ Store(index[reduction_info->GetKeptDimensionEnum()],
+ reduction_info->GetCurrentOutputLinearIndexAddress());
+ if (!reduction_info->IsRowReduction()) {
+ llvm::Type* bool_ty = b_.getInt1Ty();
+ llvm::AllocaInst* output_inbound_addr =
+ reduction_info->GetCurrentOutputInboundAddress();
+ Store(llvm::ConstantInt::get(bool_ty, 1), output_inbound_addr);
+ }
+
+ InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
+ std::vector<std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
+ extra_output_gens;
+ GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
+ GetNestedComputer());
+ FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
+ &elem_emitter);
+ absl::Span<HloInstruction* const> output_instructions =
+ GetOutputInstructions(&reduce_or_tuple);
+ // Construct the ElementGenerator for each reduction and extra output in the
+ // the group of output instructions.
+ if (unnested_hlo->opcode() == HloOpcode::kFusion) {
+ fused_emitter.SetTiledParameterInfo(tiled_param_info);
+ TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
+
+ for (int i = 0, e = output_instructions.size(); i != e; ++i) {
+ const HloInstruction* inst = output_instructions[i];
+ ShapeIndex output_shape_index;
+ if (reduce_or_tuple->opcode() == HloOpcode::kTuple) {
+ output_shape_index = {i};
+ }
+ if (inst->opcode() == HloOpcode::kReduce) {
+ input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0)));
+ } else {
+ extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst),
+ std::move(output_shape_index));
+ }
+ }
+ } else {
+ input_gens.push_back([&](const IrArray::Index& index) {
+ return GetIrArray(*unnested_hlo->operand(0), *unnested_hlo)
+ .EmitReadArrayElement(index, &b_);
+ });
+ }
+
+ IrArray::Index input_index =
+ reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex(
+ index,
+ GetFirstReduceInstruction(output_instructions)->operand(0)->shape());
+ const AddressVector& partial_reduction_result_addresses =
+ reduction_info->GetPartialResultAddresses();
+ const AddressVector& reduction_input_addresses =
+ reduction_info->GetReductionInputAddresses();
+ const InlinedVector<HloComputation*, 1>& reducers =
+ reduction_info->GetReducers();
+
+ // Emit code to generate the input and perform the reduction computation for
+ // each reduction instruction.
+ for (int i = 0; i != reducers.size(); ++i) {
+ llvm::Value* const input_ir_value = input_gens[i](input_index).ValueOrDie();
+ Store(input_ir_value, reduction_input_addresses[i]);
+ TF_CHECK_OK(EmitCallToNestedComputation(
+ *reducers[i],
+ {partial_reduction_result_addresses[i], reduction_input_addresses[i]},
+ partial_reduction_result_addresses[i]));
+ }
+
+ // Emit code to generate the output for the non-reduction instructions in the
+ // fusion, if any.
+ TF_CHECK_OK(
+ EmitExtraOutputsForReduce(unnested_hlo, input_index, extra_output_gens));
+}
+
+// Emits a kernel for the hlo instruction using the given tiling scheme.
void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile,
const KernelCodegenInfo* kernel_info,
KernelSupportLibrary& ksl,
@@ -3520,7 +2864,6 @@
<< llvm_ir::DumpToString(*param_shmem_buffers[id]);
}
- CHECK_EQ(mapping_scheme->GetThreadsPerTile() % kWarpSize, 0);
LaunchDimensions launch_dimensions = LaunchDimensions(
mapping_scheme->GetNumberOfBlocks(), mapping_scheme->GetThreadsPerTile());
llvm::Type* index_ty = GetIndexTypeForKernel(
@@ -3549,6 +2892,7 @@
kernel_info->SetLaneId(
mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x
: nullptr);
+ kernel_info->SetIndexType(index_ty);
KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
// Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck.
@@ -3573,29 +2917,31 @@
input_tile_origin.AddOffsetToDim(x, KernelMappingScheme::DimX, &b_)
.AddOffsetToDim(y, KernelMappingScheme::DimY, &b_);
- // Copy input parameter values to shared memory buffers:
- // tile[y, x] = input[index]
- // Note that tile_width and tile_height are flipped here because we are
- // reading a transposed tile.
- emit_tiled_elemental_code_with_bounds_check(
- input_index, "input", output_tile_bounds[2], output_tile_bounds[1],
- [&](const IrArray::Index& index, llvm::Value* y_loc,
- llvm::Value* x_loc) {
- for (int64 id : tiled_param_ids) {
- IrArray& input_in_logical_shape = param_in_reduced_shape_arrays[id];
- llvm::Value* shmem_buffer = param_shmem_buffers[id];
- // TODO(jlebar): Add AA metadata to this store. Tile buffers are
- // global variables, so LLVM can't infer much about it.
- Store(input_in_logical_shape.EmitReadArrayElement(index, &b_,
- "input_element"),
- GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc}));
- }
- });
-
// If shared memory transpose is needed, wait for all threads to reach this
// point, lest we copy a value from tile to output before the other thread
// copies it from input to tile. This is `__syncthreads` in CUDA.
if (!tiled_param_ids.empty()) {
+ // Copy input parameter values to shared memory buffers:
+ // tile[y, x] = input[index]
+ // Note that tile_width and tile_height are flipped here because we are
+ // reading a transposed tile.
+ emit_tiled_elemental_code_with_bounds_check(
+ input_index, "input", output_tile_bounds[2], output_tile_bounds[1],
+ [&](const IrArray::Index& index, llvm::Value* y_loc,
+ llvm::Value* x_loc) {
+ for (int64 id : tiled_param_ids) {
+ IrArray& input_in_logical_shape =
+ param_in_reduced_shape_arrays[id];
+ llvm::Value* shmem_buffer = param_shmem_buffers[id];
+ // TODO(jlebar): Add AA metadata to this store. Tile buffers are
+ // global variables, so LLVM can't infer much about it.
+ Store(input_in_logical_shape.EmitReadArrayElement(
+ index, &b_, "input_element"),
+ GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc}));
+ }
+ });
+
+ // Wait for all threads to reach this point using `__syncthreads` in CUDA.
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, &b_);
}
@@ -3615,6 +2961,7 @@
kernel_generator.GetTileElementGenerator()(unnested_hlo, index,
kernel_info, y_loc, x_loc);
});
+
// If a tile block contains multiple tiles and shared memory buffers are
// used, we need to wait for all threads to finish using the shared memory
// buffer for the current tile before we move on to process the next tile
@@ -3819,6 +3166,246 @@
return true;
}
+namespace {
+// Checks that the outputs of a fusion with reduction are consistent.
+Status AreFusedReductionOutputsConsistent(
+ absl::Span<HloInstruction* const> output_instructions,
+ const HloInstruction* first_reduce) {
+ for (const HloInstruction* inst : output_instructions) {
+ if (inst->opcode() == HloOpcode::kReduce) {
+ // Shapes, layouts and dimensions must be the same for all reduces
+ // inside of this fusion.
+ TF_RET_CHECK(ShapeUtil::Equal(first_reduce->shape(), inst->shape()));
+ TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(0)->shape(),
+ inst->operand(0)->shape()));
+ TF_RET_CHECK(ShapeUtil::Equal(first_reduce->operand(1)->shape(),
+ inst->operand(1)->shape()));
+ TF_RET_CHECK(first_reduce->dimensions() == inst->dimensions());
+ } else {
+ // For extra outputs we can relax shape equality to allow different
+ // types (with the same number of elements). Layouts still have to
+ // match.
+ TF_RET_CHECK(ShapeUtil::CompatibleIgnoringElementType(
+ first_reduce->operand(0)->shape(), inst->shape()));
+ TF_RET_CHECK(LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
+ inst->shape().layout()));
+ }
+ }
+ return Status::OK();
+}
+
+// Finds the dimensions to keep for the reduction, sorts and returns the
+// dimensions from minor to major.
+DimensionVector GetDimensionsToKeepMinorToMajor(
+ const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
+ DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0);
+ absl::c_iota(input_dims, 0);
+ DimensionVector input_dims_to_keep;
+ for (int input_dim : input_dims) {
+ auto it = absl::c_find_if(dims_to_reduce, [&](int64 dim_to_reduce) {
+ return dim_to_reduce == input_dim;
+ });
+ if (it == dims_to_reduce.end()) {
+ input_dims_to_keep.push_back(input_dim);
+ }
+ }
+
+ // Sort the dimensions to keep from minor to major.
+ absl::c_sort(input_dims_to_keep, [&input_shape](int64 dim_a, int64 dim_b) {
+ return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_a) <
+ PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_b);
+ });
+
+ VLOG(10) << "dims to keep minor to major"
+ << absl::StrJoin(input_dims_to_keep, ",");
+ return input_dims_to_keep;
+}
+
+// Given the input shape and dimensions to reduce for the reduction to vector,
+// returns <num_reduced_major, num_kept, num_reduced_minor>:
+// num_kept: the number of elements in the contiguous dimensions to keep.
+// num_reduced_major: the number of elements in the dimensions to reduce that
+// are more major than the dimensions to keep.
+// num_reduced_minor: the number of elements in the dimensions to reduce that
+// are more minor than the dimensions to kept.
+std::tuple<int64, int64, int64> GetReductionToVectorDimensions(
+ const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
+ DimensionVector input_dims_to_keep_minor_to_major =
+ GetDimensionsToKeepMinorToMajor(input_shape, dims_to_reduce);
+ CHECK(LayoutUtil::AreDimensionsConsecutive(
+ input_shape.layout(), input_dims_to_keep_minor_to_major));
+ int num_reduced_major = 1, num_kept = 1, num_reduced_minor = 1;
+ if (input_dims_to_keep_minor_to_major.empty()) {
+ return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor);
+ }
+ DimensionVector input_dims(ShapeUtil::Rank(input_shape), 0);
+ absl::c_iota(input_dims, 0);
+ absl::Span<const int64> minor_to_major =
+ LayoutUtil::MinorToMajor(input_shape);
+ for (int input_dim : input_dims) {
+ int64 curr_dim_size = input_shape.dimensions(input_dim);
+ if (PositionInContainer(minor_to_major, input_dim) >
+ PositionInContainer(minor_to_major,
+ input_dims_to_keep_minor_to_major.back())) {
+ num_reduced_major *= curr_dim_size;
+ } else if (PositionInContainer(minor_to_major, input_dim) <
+ PositionInContainer(minor_to_major,
+ input_dims_to_keep_minor_to_major.front())) {
+ num_reduced_minor *= curr_dim_size;
+ } else {
+ num_kept *= curr_dim_size;
+ }
+ }
+
+ return std::make_tuple(num_reduced_major, num_kept, num_reduced_minor);
+}
+
+std::tuple<KernelMappingScheme, bool> ComputeMappingSchemeAndReductionKind(
+ const HloInstruction* first_reduce, llvm::IRBuilder<>* b) {
+ int64 depth = 1;
+ int64 height = 1;
+ int64 width = 1;
+ bool is_row_reduction = true;
+ int64 tile_size_x = 1;
+ int64 tile_size_y = 1;
+ int64 block_size_y = 1;
+ int64 block_size_z = 1;
+ int64 num_threads_x = 1;
+ int64 num_threads_y = 1;
+ const Shape& input_shape = first_reduce->operand(0)->shape();
+ int64 num_input_elems = ShapeUtil::ElementsIn(input_shape);
+ int64 num_output_elems = ShapeUtil::ElementsIn(first_reduce->shape());
+ int64 num_reduced_major, num_kept, num_reduced_minor;
+ std::tie(num_reduced_major, num_kept, num_reduced_minor) =
+ GetReductionToVectorDimensions(input_shape, first_reduce->dimensions());
+ CHECK_EQ(num_output_elems, num_kept);
+
+ if (num_kept == 1) {
+ // Scalar reduction is a special row reduction with depth = height = 1.
+ width = num_input_elems;
+ tile_size_x = kWarpSize * 16;
+ num_threads_x = kWarpSize;
+ } else if (num_reduced_minor == 1) {
+ // Column reduction reduces inputs with dimension [height, width], where
+ // width is the minor dimension, to dimension [width].
+ height = num_reduced_major;
+ width = num_kept;
+ is_row_reduction = false;
+ tile_size_x = std::min(kWarpSize, num_kept);
+ // The old Column reduction algorithm uses kTileHeight = 128. We choose
+ // tile_size_y * block_size_y = 128 to match the value of kTileHeight. Using
+ // a non-trivial block_size_y here is a way to avoid unrolling all the 128
+ // iterations.
+ tile_size_y = 32;
+ block_size_y = 4;
+ num_threads_x = tile_size_x;
+ } else {
+ // Row reduction reduces inputs with dimension [depth, height, width],
+ // where width is the most minor dimension, to dimension [height] .
+ depth = num_reduced_major;
+ height = num_kept;
+ width = num_reduced_minor;
+ num_threads_x = kWarpSize;
+ if (width % (kWarpSize * 64) == 0) {
+ tile_size_x = kWarpSize * 64;
+ } else {
+ tile_size_x = kWarpSize * 8;
+ block_size_z = 8;
+ while (depth % block_size_z != 0) {
+ block_size_z -= 1;
+ }
+ }
+ }
+ DCHECK_EQ(depth * height * width, num_input_elems);
+ VLOG(10) << "is_row_reduction " << is_row_reduction << depth << " " << height
+ << " " << width;
+
+ DimensionVector dims_in_elem{depth, height, width};
+ DimensionVector req_block_sizes{block_size_z, block_size_y, 1};
+ llvm_ir::KernelMappingScheme mapping_scheme(dims_in_elem, tile_size_y,
+ tile_size_x, req_block_sizes,
+ num_threads_y, num_threads_x, b);
+ return std::make_tuple(mapping_scheme, is_row_reduction);
+}
+
+} // namespace
+
+Status IrEmitterUnnested::EmitReductionToVector(HloInstruction* unnested_hlo) {
+ VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString();
+
+ HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
+ ? unnested_hlo->fused_expression_root()
+ : unnested_hlo;
+ absl::Span<HloInstruction* const> output_instructions =
+ GetOutputInstructions(&reduce_or_tuple);
+ const HloInstruction* first_reduce =
+ GetFirstReduceInstruction(output_instructions);
+
+ if (output_instructions.size() > 1) {
+ TF_RETURN_IF_ERROR(
+ AreFusedReductionOutputsConsistent(output_instructions, first_reduce));
+ }
+
+ // Build an initializer thunk to initialize each reduction output.
+ std::vector<std::unique_ptr<Thunk>> thunks;
+ for (int i = 0, e = output_instructions.size(); i != e; ++i) {
+ if (output_instructions[i]->opcode() != HloOpcode::kReduce) {
+ continue;
+ }
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Thunk> initializer_thunk,
+ BuildInitializerThunk(unnested_hlo,
+ (output_instructions[i] == reduce_or_tuple)
+ ? ShapeIndex()
+ : ShapeIndex({i})));
+ thunks.push_back(std::move(initializer_thunk));
+ }
+
+ // Build a kernel thunk to compute all the outputs.
+ std::unique_ptr<KernelThunk> kernel_thunk =
+ BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false);
+
+ const Shape& input_shape = first_reduce->operand(0)->shape();
+ // The layout of a reduction input is either set by LayoutAssignment for
+ // unnested kReduce or by InstructionFusion for fused kReduce.
+ CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
+ "doesn't set the input layout of "
+ << first_reduce->ToString();
+
+ bool is_row_reduction;
+ llvm_ir::KernelMappingScheme mapping_scheme;
+ std::tie(mapping_scheme, is_row_reduction) =
+ ComputeMappingSchemeAndReductionKind(first_reduce, &b_);
+ ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction);
+ KernelCodeGenerator kernel_generator(
+ /*tile_element_generator=*/
+ [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
+ const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
+ llvm::Value* x_loc) {
+ EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc);
+ },
+ /*block_prologue_generator=*/
+ [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) {
+ EmitPrologueForReduction(hlo, kernel_info);
+ },
+ /*block_epilogue_generator*/
+ [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) {
+ EmitEpilogueForReduction(hlo, kernel_info);
+ });
+
+ LaunchDimensions launch_dimensions =
+ EmitKernel(unnested_hlo, {}, kernel_generator, &reduction_info);
+ UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
+ ir_emitter_context_->llvm_module());
+
+ thunks.push_back(std::move(kernel_thunk));
+ std::unique_ptr<SequentialThunk> sequential_thunk =
+ absl::make_unique<SequentialThunk>(std::move(thunks), unnested_hlo);
+ AddThunkToThunkSequence(std::move(sequential_thunk));
+
+ return Status::OK();
+}
+
Status IrEmitterUnnested::EmitConstantGlobals() {
for (const BufferAllocation& allocation :
ir_emitter_context_->buffer_assignment().Allocations()) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 97a1e10..85a0e53 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -16,6 +16,7 @@
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
+#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
@@ -68,9 +69,12 @@
explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme)
: mapping_scheme_(mapping_scheme),
tiled_param_info_(nullptr),
- lane_id_(nullptr) {}
+ lane_id_(nullptr),
+ index_ty_(nullptr) {}
+ virtual ~KernelCodegenInfo() {}
void SetLaneId(llvm::Value* v) { lane_id_ = v; }
+ void SetIndexType(llvm::Type* t) { index_ty_ = t; }
void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) {
CHECK_EQ(tiled_param_info_, nullptr);
tiled_param_info_ = tiled_param_info;
@@ -83,11 +87,13 @@
llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const {
return tiled_param_info_;
}
+ llvm::Type* GetIndexType() const { return index_ty_; }
private:
llvm_ir::KernelMappingScheme* mapping_scheme_;
llvm_ir::TiledParameterInfo* tiled_param_info_;
llvm::Value* lane_id_;
+ llvm::Type* index_ty_;
};
// A function object to prepare for the code generation for a tile block.
@@ -171,7 +177,7 @@
Status HandleSort(HloInstruction* sort) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleCrossReplicaSum(HloInstruction* crs) override;
- Status HandleAfterAll(HloInstruction* gen_token) override;
+ Status HandleAfterAll(HloInstruction* after_all) override;
Status EmitTargetElementLoop(
const HloInstruction& hlo,
@@ -200,82 +206,14 @@
// Helper for writing extra outputs from inside a reduce kernel.
Status EmitExtraOutputsForReduce(
- const HloInstruction* reduce, const llvm_ir::IrArray::Index& index,
+ const HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index,
absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens);
- // EmitColumnReduction and EmitRowReduction emit code for column and row
- // reduction of a matrix and/or 3D tensor. Row and column reduction have
- // different memory access pattern, so for performance their implementations
- // are significantly different.
+ // Generates code for reduction to contiguous dimensions.
//
- // Emits code that reduces a matrix of shape [height x width] to a vector of
- // [width]. Other parameters have the same meaning as those of
- // `EmitReductionToVector`. Note that input shape might not be
- // [height x width], but can be bitcast to [height x width] with "height"
- // being the major dimension.
- Status EmitColumnReduction(
- KernelThunk* kernel_thunk, int64 height, int64 width,
- HloInstruction* reduce, const Shape& input_shape,
- absl::Span<const llvm_ir::ElementGenerator> input_gens,
- absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
- absl::Span<HloComputation* const> reducers,
- absl::Span<const ShapeIndex> reduce_output_shapes,
- absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
- extra_output_gens);
-
- // Emits code that reduces a 3D tensor of shape [depth x height x width] to a
- // vector of shape [height]. Other parameters have the same meaning as those
- // of `EmitReductionToVector`. Note that input shape might not be
- // [depth x height x width], but can be bitcast to [depth x height x width]
- // with "depth" being the most major dimension.
- Status EmitRowReduction(
- KernelThunk* kernel_thunk, int64 depth, int64 height, int64 width,
- HloInstruction* reduce, const Shape& input_shape,
- absl::Span<const llvm_ir::ElementGenerator> input_gens,
- absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
- absl::Span<HloComputation* const> reducers,
- absl::Span<const ShapeIndex> reduce_output_shapes,
- absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
- extra_output_gens);
-
- // Emits code that reduces a tensor of arbitrary rank to a scalar.
- Status EmitReductionToScalar(
- KernelThunk* kernel_thunk, HloInstruction* reduce,
- const Shape& input_shape,
- absl::Span<const llvm_ir::ElementGenerator> input_gens,
- absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
- absl::Span<HloComputation* const> reducers,
- absl::Span<const ShapeIndex> reduce_output_shapes,
- absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
- extra_output_gens);
-
- // Figures out whether `reduce` is a row or column reduction, and which
- // dimensions to reduce, and calls either `EmitRowReduction` or
- // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the
- // input array, which is the operand of the Reduce instruction if unfused or
- // of the Fusion instruction if fused. `input_gen` and `init_value_gen`
- // generate elements of the input and the initial value. Other parameters mean
- // the same as for `HandleReduce`.
- //
- // Multiple reduces can be emitted in the same loop, assuming they have the
- // same input and output shapes, and the same reduce dimensions.
- //
- // extra_output_gens can contain extra generators for intermediate outputs.
- // These must have the same shape as the reduce input as they are computed
- // when the reduce inputs are being read.
- //
- // Prerequisite: `IsReductionToVector(*reduce)`
- Status EmitReductionToVector(
- KernelThunk* kernel_thunk, HloInstruction* reduce,
- const Shape& input_shape,
- absl::Span<const llvm_ir::ElementGenerator> input_gens,
- absl::Span<const llvm_ir::ElementGenerator> init_value_gens,
- absl::Span<const int64> dimensions_to_reduce,
- absl::Span<HloComputation* const> reducers,
- absl::Span<const ShapeIndex> reduce_output_shapes,
- absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
- extra_output_gens);
+ // Prerequisite: `IsReductionToVector(*unnested_hlo)`
+ Status EmitReductionToVector(HloInstruction* unnested_hlo);
// Emits code for an in-place scatter, modifying `thunk`s launch dimensions in
// the process. `scatter` may be fused, scatter indices are taken from
@@ -314,6 +252,29 @@
const llvm_ir::IrArray::Index& index,
const KernelCodegenInfo* kernel_info,
llvm::Value* y_loc, llvm::Value* x_loc);
+ // Emits code to process a tensor element in a tile for the given input hlo
+ // that is either a unnested kReduce or a kInput fusion.
+ void EmitTileElementForReduction(HloInstruction* unnested_hlo,
+ const llvm_ir::IrArray::Index& index,
+ const KernelCodegenInfo* kernel_info,
+ llvm::Value* y_loc, llvm::Value* x_loc);
+ // Prepares for the code generation for a tile block of a reduction kernel.
+ void EmitPrologueForReduction(HloInstruction* unnested_hlo,
+ KernelCodegenInfo* kernel_info);
+ void EmitPrologueForOneReduction(HloInstruction* unnested_hlo,
+ HloInstruction* reduce_inst, int reduce_idx,
+ KernelCodegenInfo* kernel_info,
+ GpuElementalIrEmitter* elemental_emitter,
+ ShapeIndex output_shape_index);
+ // Wraps up the code generation for a tile block of a reduction kernel.
+ void EmitEpilogueForReduction(HloInstruction* unnested_hlo,
+ KernelCodegenInfo* kernel_info);
+ // For each reducer, emits the shuffle-down loop to accumulate the partial
+ // result to the global result.
+ void EmitFullWarpShuffleDownLoopForAllReduces(
+ const absl::InlinedVector<HloComputation*, 1>& reducers,
+ const absl::InlinedVector<llvm::AllocaInst*, 1>&
+ partial_result_addresses);
// Generates the IrArray for each input of an hlo and returns a vector that
// constains such IrArrays.
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 913d4c3..c62c935 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -205,7 +205,8 @@
repeated HloInstructionProto instructions = 2;
// The program shape (with layout) of this computation.
- xla.ProgramShape program_shape = 4;
+
+ xla.ProgramShapeProto program_shape = 4;
// The id of this computation.
int64 id = 5;
@@ -297,7 +298,7 @@
repeated HloComputationProto computations = 3;
// The host program shape (with layout) of the entry computation.
- xla.ProgramShape host_program_shape = 4;
+ xla.ProgramShapeProto host_program_shape = 4;
// The id of this module.
int64 id = 5;
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 65bd251..d06c220 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -499,7 +499,7 @@
proto.add_instructions()->Swap(&instruction_proto);
}
proto.set_root_id(root_instruction()->unique_id());
- *proto.mutable_program_shape() = ComputeProgramShape();
+ *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
return proto;
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index fdfb38b..df7d382 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -419,6 +419,21 @@
}
Status HloCostAnalysis::HandleAfterAll(const HloInstruction*) {
+ // This instruction is used to enforce ordering at compile time. No code is
+ // emitted.
+ current_should_compute_bottleneck_time_ = false;
+ current_properties_[kBytesAccessedKey] = 0;
+ current_properties_[kOptimalSecondsKey] = 0;
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleAddDependency(
+ const HloInstruction* add_dependency) {
+ // This instruction is used to enforce ordering at compile time. No code is
+ // emitted.
+ current_should_compute_bottleneck_time_ = false;
+ current_properties_[kBytesAccessedKey] = 0;
+ current_properties_[kOptimalSecondsKey] = 0;
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 8ced9d7..3398311 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -101,6 +101,7 @@
Status HandleBroadcast(const HloInstruction* broadcast) override;
Status HandlePad(const HloInstruction* pad) override;
Status HandleReshape(const HloInstruction* reshape) override;
+ Status HandleAddDependency(const HloInstruction* add_dependency) override;
Status HandleAfterAll(const HloInstruction* token) override;
Status HandleTranspose(const HloInstruction* transpose) override;
Status HandleWhile(const HloInstruction* xla_while) override;
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 5dcf6bc..3ed3d3c 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -466,6 +466,21 @@
return changed;
}
+bool HloDataflowAnalysis::UpdateAddDependencyValueSet(
+ HloInstruction* add_dependency) {
+ // AddDependency just forwards the value of its zero-th operand.
+ CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency);
+ const InstructionValueSet& operand_set =
+ GetInstructionValueSet(add_dependency->operand(0));
+ InstructionValueSet& add_dependency_set =
+ GetInstructionValueSet(add_dependency);
+ if (operand_set != add_dependency_set) {
+ add_dependency_set = operand_set;
+ return true;
+ }
+ return false;
+}
+
bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
bool changed = false;
@@ -622,6 +637,8 @@
HloInstruction* instruction) {
// Recompute from operands.
switch (instruction->opcode()) {
+ case HloOpcode::kAddDependency:
+ return UpdateAddDependencyValueSet(instruction);
case HloOpcode::kBitcast:
return UpdateBitcastValueSet(instruction);
case HloOpcode::kDomain:
@@ -795,6 +812,7 @@
define_all_values();
}
break;
+ case HloOpcode::kAddDependency:
case HloOpcode::kWhile:
case HloOpcode::kCall:
case HloOpcode::kConditional:
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index abac398..ece17fc 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -193,6 +193,7 @@
bool UpdateSendValueSet(HloInstruction* send);
bool UpdateTupleValueSet(HloInstruction* tuple);
bool UpdateWhileValueSet(HloInstruction* xla_while);
+ bool UpdateAddDependencyValueSet(HloInstruction* add_dependency);
// Propagate the dataflow through the module.
void Propagate();
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index e8eb706..f7a1f19 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1877,6 +1877,30 @@
}
}
+TEST_P(HloDataflowAnalysisTest, AddDependency) {
+ string module_string = R"(
+HloModule AddDependency
+ENTRY %AddDependency (p: f32[3]) -> f32[3] {
+ %p = f32[3] parameter(0)
+ %token = token[] after-all()
+ ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseHloString(module_string, GetModuleConfigForTest()));
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> analysis,
+ HloDataflowAnalysis::Run(*module));
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kAddDependency);
+
+ // The after-all and parameter should define a value. Add-dependency should
+ // not.
+ EXPECT_EQ(analysis->values().size(), 2);
+ EXPECT_FALSE(analysis->ValueIsDefinedAt(root));
+}
+
INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
HloDataflowAnalysisTest,
::testing::Values(false, true));
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 7fcafaf..51a3fba 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1046,8 +1046,15 @@
return Status::OK();
}
-Status HloEvaluator::HandleAfterAll(HloInstruction* token) {
- evaluated_[token] = LiteralUtil::CreateToken();
+Status HloEvaluator::HandleAfterAll(HloInstruction* after_all) {
+ evaluated_[after_all] = LiteralUtil::CreateToken();
+ return Status::OK();
+}
+
+Status HloEvaluator::HandleAddDependency(HloInstruction* add_dependency) {
+ // AddDedendency just forwards its zero-th operand.
+ evaluated_[add_dependency] =
+ GetEvaluatedLiteralFor(add_dependency->operand(0)).Clone();
return Status::OK();
}
@@ -1279,10 +1286,10 @@
key_value_vector.push_back(
std::make_pair(keys_data[i], values_data[i]));
}
- std::sort(key_value_vector.begin(), key_value_vector.end(),
- [](const kv_pair& a, const kv_pair& b) {
- return SafeLess<KeyType>(a.first, b.first);
- });
+ std::stable_sort(key_value_vector.begin(), key_value_vector.end(),
+ [](const kv_pair& a, const kv_pair& b) {
+ return SafeLess<KeyType>(a.first, b.first);
+ });
std::vector<KeyType> result_keys;
// We use a InlinedVector here because we need to convert it to an
// absl::Span later, and this would not work with std::vector<bool>.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index d751f40..d847900 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -180,7 +180,9 @@
Status HandleBroadcast(HloInstruction* broadcast) override;
- Status HandleAfterAll(HloInstruction* token) override;
+ Status HandleAfterAll(HloInstruction* after_all) override;
+
+ Status HandleAddDependency(HloInstruction* add_dependency) override;
Status HandleSort(HloInstruction* sort) override;
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 332fa87..b87fc3e 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -1553,10 +1553,10 @@
const auto& row_data = row_to_sort.data<NativeT>();
std::vector<NativeT> result_data(row_data.begin(), row_data.end());
- std::sort(result_data.begin(), result_data.end(),
- [](const NativeT& a, const NativeT& b) {
- return SafeLess<NativeT>(a, b);
- });
+ std::stable_sort(result_data.begin(), result_data.end(),
+ [](const NativeT& a, const NativeT& b) {
+ return SafeLess<NativeT>(a, b);
+ });
Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(),
{sort_dim_elements}));
sorted_row.PopulateR1(absl::Span<const NativeT>(result_data));
@@ -2543,12 +2543,14 @@
template <typename NativeT,
typename std::enable_if<
- std::is_same<NativeT, float>::value ||
- std::is_same<NativeT, int32>::value ||
- std::is_same<NativeT, uint32>::value>::type* = nullptr>
+ std::is_integral<NativeT>::value ||
+ std::is_floating_point<NativeT>::value>::type* = nullptr>
Status HandleIota(HloInstruction* instruction) {
auto* iota = Cast<HloIotaInstruction>(instruction);
- std::vector<NativeT> data(iota->shape().dimensions(iota->iota_dimension()));
+ // Avoid using std::vector since std::vector<bool> does not convert to
+ // absl::Span<bool>.
+ absl::InlinedVector<NativeT, 1> data(
+ iota->shape().dimensions(iota->iota_dimension()));
std::iota(data.begin(), data.end(), 0);
auto result = LiteralUtil::CreateR1<NativeT>(data);
@@ -2565,9 +2567,8 @@
}
template <typename NativeT,
typename std::enable_if<
- !(std::is_same<NativeT, float>::value ||
- std::is_same<NativeT, int32>::value ||
- std::is_same<NativeT, uint32>::value)>::type* = nullptr>
+ !(std::is_integral<NativeT>::value ||
+ std::is_floating_point<NativeT>::value)>::type* = nullptr>
Status HandleIota(HloInstruction* iota) {
return InvalidArgument("Unsupported type for iota");
}
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 05cc159..7e9e94c 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -987,6 +987,7 @@
case HloOpcode::kGetTupleElement:
case HloOpcode::kTrace:
case HloOpcode::kAfterAll:
+ case HloOpcode::kAddDependency:
case HloOpcode::kTuple:
return kWhite;
case HloOpcode::kBroadcast:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index cd95052..1e3881c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -855,6 +855,16 @@
new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
}
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateAddDependency(HloInstruction* data_operand,
+ HloInstruction* token_operand) {
+ auto instruction = absl::WrapUnique(
+ new HloInstruction(HloOpcode::kAddDependency, data_operand->shape()));
+ instruction->AppendOperand(data_operand);
+ instruction->AppendOperand(token_operand);
+ return instruction;
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
const Shape& shape, HloComputation* condition, HloComputation* body,
HloInstruction* init) {
@@ -1394,6 +1404,10 @@
clone = CreateAfterAll(new_operands);
}
break;
+ case HloOpcode::kAddDependency:
+ CHECK_EQ(new_operands.size(), 2);
+ clone = CreateAddDependency(new_operands[0], new_operands[1]);
+ break;
}
// SetupDerivedInstruction will setup the precision_config_ field.
SetupDerivedInstruction(clone.get());
@@ -1680,6 +1694,7 @@
// This opcode has complex or special behavior so just return false.
case HloOpcode::kAfterAll:
+ case HloOpcode::kAddDependency:
return false;
// Remaining instructions with special values.
@@ -2467,6 +2482,8 @@
return visitor->HandleDomain(this);
case HloOpcode::kAfterAll:
return visitor->HandleAfterAll(this);
+ case HloOpcode::kAddDependency:
+ return visitor->HandleAddDependency(this);
case HloOpcode::kIota:
return visitor->HandleIota(this);
case HloOpcode::kGetDimensionSize:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 95ad292..87748a7 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -770,6 +770,9 @@
static std::unique_ptr<HloInstruction> CreateGetDimensionSize(
const Shape& shape, HloInstruction* operand, int64 dimension);
+ static std::unique_ptr<HloInstruction> CreateAddDependency(
+ HloInstruction* data_operand, HloInstruction* token_operand);
+
// Returns the opcode for this instruction.
HloOpcode opcode() const { return opcode_; }
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc
index 5269cad..d28e79d 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers.cc
@@ -237,8 +237,4 @@
*os << (inst ? inst->ToString() : "nullptr");
}
-void PrintTo(HloInstruction* inst, ::std::ostream* os) {
- PrintTo(const_cast<const HloInstruction*>(inst), os);
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 170ec93..235efb1 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -385,7 +385,6 @@
// Tell GMock to print HloInstruction* by value, so error messages are nice.
// Has to be in the same namespace as 'HloInstruction'.
void PrintTo(const HloInstruction* inst, ::std::ostream* os);
-void PrintTo(HloInstruction* inst, ::std::ostream* os);
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 59f4447..a01853f 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -240,7 +240,7 @@
*proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
}
*proto.mutable_host_program_shape() =
- entry_computation_layout().ComputeProgramShape();
+ entry_computation_layout().ComputeProgramShape().ToProto();
*proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
*proto.mutable_dynamic_parameter_binding() =
dynamic_parameter_binding().ToProto();
@@ -371,7 +371,7 @@
<< "No program shape found in the proto";
const auto& program_shape = module.host_program_shape();
- HloModuleConfig module_config(program_shape);
+ HloModuleConfig module_config(ProgramShape{program_shape});
module_config.set_debug_options(debug_options);
// The module config is constructed with default layouts regardless of what is
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 70c7d70..127cfd1 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -47,6 +47,8 @@
#define HLO_OPCODE_LIST(V) \
V(kAbs, "abs") \
V(kAdd, "add") \
+ V(kAddDependency, "add-dependency") \
+ V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \
V(kAllToAll, "all-to-all") \
V(kAtan2, "atan2") \
V(kBatchNormGrad, "batch-norm-grad") \
@@ -84,7 +86,6 @@
V(kGather, "gather") \
V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
V(kGetDimensionSize, "get-dimension-size") \
- V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \
V(kGetTupleElement, "get-tuple-element") \
V(kGt, "greater-than", kHloOpcodeIsComparison) \
V(kImag, "imag") \
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 4bf287a..9b5bb5d 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -850,6 +850,15 @@
}
break;
}
+ case HloOpcode::kAddDependency: {
+ if (!ParseOperands(&operands, /*expected_size=*/2) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateAddDependency(operands[0], operands[1]));
+ break;
+ }
case HloOpcode::kSort: {
optional<std::vector<tensorflow::int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 88682e5..f13f750 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1241,7 +1241,38 @@
}
)"
+ },
+// AfterAll with multiple operands
+{
+"AfterAllWithMultipleOperands",
+R"(HloModule AfterAllWithMultipleOperands
+
+ENTRY AfterAllWithMultipleOperands {
+ p0 = f32[] parameter(0)
+ token0 = token[] after-all()
+ token1 = token[] after-all()
+ ROOT after-all = token[] after-all(p0, token0, token1)
}
+
+)"
+},
+// AddDependency
+// A dependency chain is created from 'neg' to 'exp' using tokens.
+{
+"AddDependency",
+R"(HloModule AddDependency
+
+ENTRY AddDependency {
+ p = f32[] parameter(0)
+ neg = f32[] negate(p)
+ token = token[] after-all(neg)
+ p_after_token = f32[] add-dependency(p, token)
+ exp = f32[] exponential(p_after_token)
+ ROOT sum = f32[] add(neg, exp)
+}
+
+)"
+},
});
// clang-format on
}
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
index 88329c8..f506130 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc
@@ -253,7 +253,7 @@
instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice));
for (HloInstruction* user : instruction->users()) {
if (user->opcode() == HloOpcode::kDomain &&
- domain.exit_domains.count(const_cast<HloInstruction*>(user)) > 0) {
+ domain.exit_domains.count(user) > 0) {
// If a user is a domain and it is registered in the domain exits, then
// the instruction sharding is taken directly from the domain, and no
// further users need to be visited.
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 60d8a51..77db7b0 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -753,7 +753,13 @@
for (const HloInstruction* operand : token->operands()) {
operand_shapes.push_back(&operand->shape());
}
- return CheckShape(token, ShapeInference::InferAfterAllShape(operand_shapes));
+ return CheckShape(token, ShapeUtil::MakeTokenShape());
+}
+
+Status ShapeVerifier::HandleAddDependency(HloInstruction* add_dependency) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(add_dependency, 2));
+ TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1));
+ return CheckShape(add_dependency, add_dependency->operand(0)->shape());
}
Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
@@ -1373,9 +1379,8 @@
const Layout& operand_layout = operand_shape.layout();
TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
<< "Instruction shouldn't change layouts "
- << instruction->ToString() << " From "
- << ShapeUtil::HumanString(result_shape) << " To "
- << ShapeUtil::HumanString(operand_shape);
+ << instruction->ToString() << " From " << result_shape << " To "
+ << operand_shape;
}
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 9fbfd6a..e4d0c3d 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -95,6 +95,7 @@
Status HandleScatter(HloInstruction* scatter) override;
Status HandleAfterAll(HloInstruction* token) override;
Status HandleGetDimensionSize(HloInstruction* get_size) override;
+ Status HandleAddDependency(HloInstruction* add_dependency) override;
Status FinishVisit(HloInstruction*) override { return Status::OK(); }
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 7f2d7e7..2297edc 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -103,7 +103,6 @@
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSlice:
case HloOpcode::kSubtract:
- case HloOpcode::kAfterAll:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
case HloOpcode::kTupleSelect:
@@ -116,7 +115,10 @@
case HloOpcode::kSin:
return ShapeUtil::ElementIsComplex(instruction.shape());
- // Expensive instructions.
+ // Expensive instructions or unusual instructions for which fusion is
+ // nonsensical.
+ case HloOpcode::kAddDependency:
+ case HloOpcode::kAfterAll:
case HloOpcode::kAtan2:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kBatchNormInference:
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc
index 4fb67bd..e3e5fa7 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executor.cc
@@ -78,9 +78,14 @@
return port::Status::OK();
}
-bool XlaInterpreterExecutor::HostCallback(Stream *stream,
- std::function<void()> callback) {
- AsExecutorStream(stream)->EnqueueTask(callback);
+bool XlaInterpreterExecutor::HostCallback(
+ Stream *stream, std::function<port::Status()> callback) {
+ AsExecutorStream(stream)->EnqueueTask([callback]() {
+ port::Status s = callback();
+ if (!s.ok()) {
+ LOG(WARNING) << "Host callback failed: " << s;
+ }
+ });
return true;
}
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h
index fbb9945..400c305 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.h
+++ b/tensorflow/compiler/xla/service/interpreter/executor.h
@@ -125,7 +125,8 @@
return port::Status{port::error::UNIMPLEMENTED, ""};
}
- bool HostCallback(Stream *stream, std::function<void()> callback) override;
+ bool HostCallback(Stream *stream,
+ std::function<port::Status()> callback) override;
port::Status AllocateEvent(Event *event) override {
return port::Status{port::error::UNIMPLEMENTED, ""};
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index a904119..eddef85 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -2000,6 +2000,7 @@
switch (instruction->opcode()) {
case HloOpcode::kAbs:
case HloOpcode::kAdd:
+ case HloOpcode::kAddDependency:
case HloOpcode::kAnd:
case HloOpcode::kAtan2:
case HloOpcode::kBitcastConvert:
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
index c26711e..1aa85eb 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc
@@ -120,7 +120,7 @@
absl::Span<const int64> req_block_sizes, int64 num_threads_y,
int64 num_threads_x, llvm::IRBuilder<>* b)
: b_(b),
- dims_in_elems_(dims_in_elems),
+ dims_in_elems_(dims_in_elems.begin(), dims_in_elems.end()),
tile_sizes_{1, tile_size_y, tile_size_x},
num_threads_x_(num_threads_x),
num_threads_y_(num_threads_y) {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
index 06002d5..7277aea 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h
@@ -90,15 +90,16 @@
enum { DimZ = 0, DimY, DimX, DimTot };
public:
+ KernelMappingScheme() {}
// dims_in_elems: the normalized tensor dimensions.
// req_block_sizes: the requested block size in number of tiles for each
// dimension. The actual block size is set to min(req_block_size,
// dims_in_number_of_blocks).
- explicit KernelMappingScheme(absl::Span<const int64> dims_in_elems,
- int64 tile_size_y, int64 tile_size_x,
- absl::Span<const int64> req_block_sizes,
- int64 num_threads_y, int64 num_threads_x,
- llvm::IRBuilder<>* b);
+ KernelMappingScheme(absl::Span<const int64> dims_in_elems, int64 tile_size_y,
+ int64 tile_size_x,
+ absl::Span<const int64> req_block_sizes,
+ int64 num_threads_y, int64 num_threads_x,
+ llvm::IRBuilder<>* b);
absl::Span<const int64> GetDimensionsInElements() const {
return dims_in_elems_;
@@ -133,6 +134,10 @@
}
absl::Span<const int64> GetBlockSizes() const { return block_sizes_; }
+ int64 GetTileBlockSizeForDimension(int d) const {
+ DCHECK(d >= DimZ && d <= DimX);
+ return dims_in_blocks_[d];
+ }
int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; }
int64 GetNumberOfThreadsForDimensionY() const { return num_threads_y_; }
@@ -163,7 +168,7 @@
private:
llvm::IRBuilder<>* b_;
// The number of elements in each dimension.
- absl::Span<const int64> dims_in_elems_;
+ std::vector<int64> dims_in_elems_;
// The number of elements for each dimension of a tile.
std::vector<int64> tile_sizes_;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
index fd16af6..e22c217 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc
@@ -47,7 +47,8 @@
// Adds the inner comparison loop body where we compare elements.
void EmitCompareLoopBody(
int64 iteration_bound, PrimitiveType key_type, int64 num_values,
- llvm::Value* element_pair_index, int64 xor_mask, llvm::Type* index_type,
+ int64 iota_values_parameter_index, llvm::Value* element_pair_index,
+ int64 xor_mask, llvm::Type* index_type,
std::function<llvm::Value*(int64 operand, llvm::Value* index)> read_element,
std::function<void(int64 operand, llvm::Value* index, llvm::Value* value)>
write_element,
@@ -139,34 +140,42 @@
is_signed_comparison = false;
}
// If key2 < key1
- ksl.IfReturnVoid(
- "is_smaller_than",
+ auto is_smaller_than =
b->CreateICmp(is_signed_comparison ? llvm::ICmpInst::ICMP_SLT
: llvm::ICmpInst::ICMP_ULT,
- compare_key2, compare_key1),
- [&]() {
- // Swap key1 with key2.
- write_element(0, current_keys_index, key2);
- write_element(0, compare_keys_index, key1);
- for (int64 i = 1; i <= num_values; ++i) {
- // Also swap the values.
- auto value1 = read_element(i, current_keys_index);
- auto value2 = read_element(i, compare_keys_index);
- write_element(i, current_keys_index, value2);
- write_element(i, compare_keys_index, value1);
- }
- });
+ compare_key2, compare_key1);
+ if (iota_values_parameter_index >= 0) {
+ auto keys_equal = b->CreateICmpEQ(compare_key1, compare_key2);
+ auto key_index1 =
+ read_element(iota_values_parameter_index, current_keys_index);
+ auto key_index2 =
+ read_element(iota_values_parameter_index, compare_keys_index);
+ auto index_is_smaller_than =
+ b->CreateICmp(llvm::ICmpInst::ICMP_ULT, key_index2, key_index1);
+ is_smaller_than = b->CreateOr(
+ is_smaller_than, b->CreateAnd(keys_equal, index_is_smaller_than));
+ }
+ ksl.IfReturnVoid("is_smaller_than", is_smaller_than, [&]() {
+ // Swap key1 with key2.
+ write_element(0, current_keys_index, key2);
+ write_element(0, compare_keys_index, key1);
+ for (int64 i = 1; i <= num_values; ++i) {
+ // Also swap the values.
+ auto value1 = read_element(i, current_keys_index);
+ auto value2 = read_element(i, compare_keys_index);
+ write_element(i, current_keys_index, value2);
+ write_element(i, compare_keys_index, value1);
+ }
+ });
});
}
-void EmitTiledCompareLoop(const IrArray::Index& tiled_keys_index,
- int64 dimension_to_sort,
- int64 dimension_to_sort_bound,
- PrimitiveType keys_type,
- absl::Span<const int64> xor_masks,
- const std::vector<IrArray>& params,
- const std::vector<llvm::Value*>& param_shmem_buffers,
- int64 tile_size, llvm::IRBuilder<>* b) {
+void EmitTiledCompareLoop(
+ const IrArray::Index& tiled_keys_index, int64 dimension_to_sort,
+ int64 dimension_to_sort_bound, PrimitiveType keys_type,
+ absl::Span<const int64> xor_masks, const std::vector<IrArray>& params,
+ const std::vector<llvm::Value*>& param_shmem_buffers,
+ int64 iota_values_parameter_index, int64 tile_size, llvm::IRBuilder<>* b) {
KernelSupportLibrary ksl(b);
llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic(
llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b);
@@ -253,20 +262,22 @@
RoundDownToNearest(dimension_to_sort_bound, tile_size))),
[&]() {
EmitCompareLoopBody(dimension_to_sort_bound % tile_size, keys_type,
- params.size() - 1, element_pair_index, xor_mask,
+ params.size() - 1, iota_values_parameter_index,
+ element_pair_index, xor_mask,
tiled_keys_index.GetType(), read_element,
write_element, b);
},
[&]() {
- EmitCompareLoopBody(
- tile_size, keys_type, params.size() - 1, element_pair_index,
- xor_mask, tiled_keys_index.GetType(), read_element,
- write_element, b, /*needs_bounds_checks=*/false);
+ EmitCompareLoopBody(tile_size, keys_type, params.size() - 1,
+ iota_values_parameter_index, element_pair_index,
+ xor_mask, tiled_keys_index.GetType(),
+ read_element, write_element, b,
+ /*needs_bounds_checks=*/false);
});
} else {
EmitCompareLoopBody(tile_size, keys_type, params.size() - 1,
- element_pair_index, xor_mask,
- tiled_keys_index.GetType(), read_element,
+ iota_values_parameter_index, element_pair_index,
+ xor_mask, tiled_keys_index.GetType(), read_element,
write_element, b, /*needs_bounds_checks=*/false);
}
// Wait until all comparisons have happened.
@@ -296,6 +307,7 @@
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
const std::vector<IrArray>& values_arrays,
+ int64 iota_values_parameter_index,
absl::string_view name,
absl::Span<const int64> xor_masks, llvm::IRBuilder<>* b,
const gpu::LaunchDimensions& launch_dimensions,
@@ -367,8 +379,8 @@
if (xor_masks.size() > 1) {
EmitTiledCompareLoop(keys_index, dimension_to_sort,
dimension_to_sort_bound, keys_shape.element_type(),
- xor_masks, params, param_shmem_buffers, tile_size,
- b);
+ xor_masks, params, param_shmem_buffers,
+ iota_values_parameter_index, tile_size, b);
} else {
auto read_element = [&](int64 operand, llvm::Value* index) {
keys_index[dimension_to_sort] = index;
@@ -380,9 +392,10 @@
params[operand].EmitWriteArrayElement(keys_index, value, b);
};
EmitCompareLoopBody(dimension_to_sort_bound, keys_shape.element_type(),
- values_arrays.size(), tiles_index[rank - 1],
- xor_masks[0], tiles_index.GetType(), read_element,
- write_element, b);
+ values_arrays.size(), iota_values_parameter_index,
+ tiles_index[rank - 1], xor_masks[0],
+ tiles_index.GetType(), read_element, write_element,
+ b);
}
return Status::OK();
};
diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
index 556a217..685f938 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h
@@ -31,9 +31,12 @@
// Emits llvm IR to do pairwise comparisons/swaps in the 'dimension_to_sort'
// dimension of 'keys_array'. All other dimensions are kept as-is. This
// implements the inner loop of BitonicSort. It is assumed that 'xor_masks'
-// contains only powers of 2, or values 2^k - 1 (k > 0).
+// contains only powers of 2, or values 2^k - 1 (k > 0). If
+// 'iota_values_parameter_index' is >= 0, it points at a 'values_arrays' operand
+// that is a iota and can be used to make the sorting stable.
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
const std::vector<IrArray>& values_arrays,
+ int64 iota_values_parameter_index,
absl::string_view name,
absl::Span<const int64> xor_masks, llvm::IRBuilder<>* b,
const gpu::LaunchDimensions& launch_dimensions,
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 2180ac8..ddc8691 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -96,28 +96,8 @@
const ExecutableBuildOptions& build_options,
const ProgramShape* program_shape) {
ExecutionOptions execution_options = CreateDefaultExecutionOptions();
- if (build_options.hlo_profile().has_value()) {
- execution_options.mutable_debug_options()->set_xla_hlo_profile(
- *build_options.hlo_profile());
- }
- if (build_options.generate_hlo_graph().has_value()) {
- execution_options.mutable_debug_options()->set_xla_generate_hlo_graph(
- build_options.generate_hlo_graph().value());
- }
- if (build_options.dump_optimized_hlo_proto_to().has_value()) {
- execution_options.mutable_debug_options()
- ->set_xla_dump_optimized_hlo_proto_to(
- build_options.dump_optimized_hlo_proto_to().value());
- }
- if (build_options.dump_unoptimized_hlo_proto_to().has_value()) {
- execution_options.mutable_debug_options()
- ->set_xla_dump_unoptimized_hlo_proto_to(
- build_options.dump_unoptimized_hlo_proto_to().value());
- }
- if (build_options.dump_per_pass_hlo_proto_to().has_value()) {
- execution_options.mutable_debug_options()
- ->set_xla_dump_per_pass_hlo_proto_to(
- build_options.dump_per_pass_hlo_proto_to().value());
+ if (build_options.has_debug_options()) {
+ *execution_options.mutable_debug_options() = build_options.debug_options();
}
if (build_options.result_layout() != nullptr) {
*execution_options.mutable_shape_with_output_layout() =
@@ -128,12 +108,6 @@
LayoutUtil::SetToDefaultLayout(
execution_options.mutable_shape_with_output_layout());
}
-
- for (const std::string& disabled_pass : build_options.disabled_hlo_passes()) {
- execution_options.mutable_debug_options()->add_xla_disable_hlo_passes(
- disabled_pass);
- }
-
return execution_options;
}
@@ -145,7 +119,7 @@
const ExecutableBuildOptions& build_options) {
const HloModuleProto& proto = computation.proto();
TF_RET_CHECK(proto.has_host_program_shape());
- const ProgramShape& program_shape = proto.host_program_shape();
+ ProgramShape program_shape(proto.host_program_shape());
// Validate incoming layouts.
if (argument_layouts.size() != program_shape.parameters_size()) {
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
index ec52a24..972a5b9 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
@@ -113,6 +113,13 @@
return Status::OK();
}
+Status LogicalBufferAnalysis::HandleAddDependency(
+ HloInstruction* add_dependency) {
+ // AddDependency just forwards the value of its zero-th operand and does not
+ // create buffers.
+ return Status::OK();
+}
+
Status LogicalBufferAnalysis::HandleCopy(HloInstruction* copy) {
// The top-level buffer (index={}) for kCopy is newly created, but all other
// buffers (in the case of a tuple shape) come from the operand
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
index 81f524d..7ffca94 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
@@ -64,6 +64,7 @@
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
+ Status HandleAddDependency(HloInstruction* add_dependency) override;
// A map from the buffer ID to the logical buffer
std::vector<std::unique_ptr<LogicalBuffer>> logical_buffers_;
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 13fd6bc..c4b0a5c 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -658,9 +658,9 @@
// replica 0.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(request.computation().host_program_shape(),
- replicated_arguments.front(),
- request.execution_options()));
+ CreateModuleConfig(
+ ProgramShape{request.computation().host_program_shape()},
+ replicated_arguments.front(), request.execution_options()));
VLOG(3)
<< "ExecuteGraphParallel created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -824,7 +824,7 @@
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
- CreateModuleConfig(arg->computation().host_program_shape(),
+ CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
argument_shapes, &arg->execution_options()));
VLOG(3) << "Compile created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@@ -1072,7 +1072,7 @@
"constant computation may not depend on any parameters.");
}
- ProgramShape program_shape = arg->computation().host_program_shape();
+ ProgramShape program_shape(arg->computation().host_program_shape());
TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
if (arg->has_output_layout()) {
TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
@@ -1116,7 +1116,7 @@
return InvalidArgument("Program shape may not be empty.");
}
- HloModuleConfig config(arg->computation().host_program_shape());
+ HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()});
config.set_debug_options(arg->debug_options());
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
CreateModuleFromProto(arg->computation(), config));
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 2bfc167..528d5c0 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -391,17 +391,6 @@
return ShapeUtil::MakeShape(element_type, new_dimensions);
}
-/* static */ StatusOr<Shape> ShapeInference::InferAfterAllShape(
- absl::Span<const Shape* const> arg_shapes) {
- for (const Shape* arg_shape : arg_shapes) {
- if (arg_shape->element_type() != TOKEN) {
- return InvalidArgument(
- "Operands of token instructions must be TOKEN types.");
- }
- }
- return ShapeUtil::MakeTokenShape();
-}
-
/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
const Shape& operand_shape, PrimitiveType new_element_type) {
auto old_element_type = operand_shape.element_type();
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 31ef4b2..d94385a 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -232,13 +232,6 @@
static StatusOr<Shape> InferConcatOpShape(
absl::Span<const Shape* const> arg_shapes, int64 dimension);
- // Infers the shape produced by a kAfterAll. Trivially this shape is always a
- // TOKEN shape. However, ShapeInference serves two purposes: inferring shapes
- // and checking operand shapes. This method verifies that the operand shapes
- // are all TOKENs.
- static StatusOr<Shape> InferAfterAllShape(
- absl::Span<const Shape* const> arg_shapes);
-
// Helper that validates the given operand shape can be converted to the
// target output_shape via a convert instruction -- the requirement is that
// the shape is identical except for the element type.
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index 96f3055..50d51ea 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -280,6 +280,13 @@
return Status::OK();
}
+Status TuplePointsToAnalysis::HandleAddDependency(
+ HloInstruction* add_dependency) {
+ // AddDependency just forwards the value of its zero-th operand.
+ CreateCopiedPointsToSet(add_dependency, add_dependency->operand(0));
+ return Status::OK();
+}
+
Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
// RecvDone aliases its input (Recv) tuple element {0} to element {0} of its
// output. The other indices ({} and {1}) define their own buffers.
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index bcfcb38..0a1d564 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -252,6 +252,7 @@
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandleSend(HloInstruction* send) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
+ Status HandleAddDependency(HloInstruction* add_dependency) override;
string ToString() const;
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 10ef2d3..561762b 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -264,6 +264,22 @@
UnorderedElementsAre(inner_tuple));
}
+TEST_F(TuplePointsToAnalysisTest, AddDependency) {
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
+ auto add_dependency = builder.AddInstruction(
+ HloInstruction::CreateAddDependency(constant, token));
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ auto& points_to_set = points_to_analysis_->GetPointsToSet(add_dependency);
+ EXPECT_EQ(1, points_to_set.size());
+ EXPECT_FALSE(points_to_set.IsAmbiguous());
+ EXPECT_TRUE(points_to_set.IsDistinct());
+ ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(), {constant});
+}
+
TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) {
// Create a tuple which contains duplicate elements.
auto builder = HloComputation::Builder(TestName());
diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc
new file mode 100644
index 0000000..d209389
--- /dev/null
+++ b/tensorflow/compiler/xla/shape.cc
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/shape.h"
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+
+namespace xla {
+
+ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
+ for (const Shape& shape : program_shape_proto.parameters()) {
+ *add_parameters() = shape;
+ }
+ *mutable_result() = program_shape_proto.result();
+ for (const string& name : program_shape_proto.parameter_names()) {
+ add_parameter_names(name);
+ }
+}
+
+ProgramShapeProto ProgramShape::ToProto() const {
+ ProgramShapeProto proto;
+ for (const Shape& shape : parameters()) {
+ *proto.add_parameters() = shape;
+ }
+ *proto.mutable_result() = result();
+ for (const string& name : parameter_names()) {
+ proto.add_parameter_names(name);
+ }
+ return proto;
+}
+
+string ProgramShape::ToString() const {
+ std::vector<string> parameter_strings(parameters_size());
+ for (int i = 0; i < parameters_size(); ++i) {
+ parameter_strings[i] = absl::StrCat(
+ i < parameter_names_size() ? parameter_names(i) : "(unknown)", ": ",
+ ShapeUtil::HumanString(parameters(i)));
+ }
+ return absl::StrCat("(", absl::StrJoin(parameter_strings, ", "), ") -> ",
+ ShapeUtil::HumanString(result()));
+}
+
+std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape) {
+ out << program_shape.ToString() << "\n";
+ return out;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h
new file mode 100644
index 0000000..c3aecb1
--- /dev/null
+++ b/tensorflow/compiler/xla/shape.h
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_
+#define TENSORFLOW_COMPILER_XLA_SHAPE_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Shape of the parameters and output of an XLA computation. This is analogous
+// to a traditional function signature.
+class ProgramShape {
+ public:
+ ProgramShape() = default;
+
+ // Creates a ProgramShape from a ProgramShapeProto protobuf.
+ explicit ProgramShape(const ProgramShapeProto& program_shape_proto);
+
+ // Returns a proto representation of the object.
+ ProgramShapeProto ToProto() const;
+
+ string ToString() const;
+
+ // The following methods mirror the protobuf generated code interface for the
+ // message ProgramShapeProto. This enabled easy migration of this data
+ // structure from a proto to a proper C++ class.
+ // TODO(b/29771030): Replace or augment these methods with a more ergonomic
+ // interface.
+
+ // Methods for accessing and manipulating the Shape of the parameters.
+ int parameters_size() const { return parameters_.size(); }
+ const Shape& parameters(int index) const { return parameters_.at(index); }
+ Shape* mutable_parameters(int index) { return ¶meters_.at(index); }
+ Shape* add_parameters() {
+ parameters_.emplace_back();
+ return ¶meters_.back();
+ }
+ void clear_parameters() { parameters_.clear(); }
+ const std::vector<Shape>& parameters() const { return parameters_; }
+ std::vector<Shape>* mutable_parameters() { return ¶meters_; }
+
+ // Methods for accessing and manipulating the Shape of the result.
+ const Shape& result() const { return result_; }
+ Shape* mutable_result() { return &result_; }
+ void clear_result() { result_.Clear(); }
+
+ // Methods for accessing and manipulating the names of the parameters.
+ int parameter_names_size() const { return parameter_names_.size(); }
+ const string& parameter_names(int index) const {
+ return parameter_names_.at(index);
+ }
+ void set_parameter_names(int index, const string& value) {
+ parameter_names_.at(index) = value;
+ }
+ string* mutable_parameter_names(int index) {
+ return ¶meter_names_.at(index);
+ }
+ void add_parameter_names(const string& value) {
+ parameter_names_.push_back(value);
+ }
+ string* add_parameter_names() {
+ parameter_names_.push_back("");
+ return ¶meter_names_.back();
+ }
+ void clear_parameter_names() { parameter_names_.clear(); }
+ const std::vector<string>& parameter_names() const {
+ return parameter_names_;
+ }
+ std::vector<string>* mutable_parameter_names() { return ¶meter_names_; }
+
+ string ShortDebugString() const { return ToProto().ShortDebugString(); }
+ string DebugString() const { return ToProto().DebugString(); }
+
+ private:
+ // The shapes of the parameters of the computation represented by this object.
+ std::vector<Shape> parameters_;
+
+ // The names of the parameters of the computation represented by this object.
+ std::vector<string> parameter_names_;
+
+ // The shape of the result of the computation represented by this object.
+ Shape result_;
+};
+
+std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SHAPE_H_
diff --git a/tensorflow/compiler/xla/shape_test.cc b/tensorflow/compiler/xla/shape_test.cc
new file mode 100644
index 0000000..cc3a5eb
--- /dev/null
+++ b/tensorflow/compiler/xla/shape_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/shape.h"
+
+#include <numeric>
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+TEST(ShapeTest, ProgramShapeToFromProto) {
+ ProgramShape program_shape;
+ *program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3});
+ *program_shape.add_parameters() = ShapeUtil::MakeTokenShape();
+ *program_shape.add_parameters() = ShapeUtil::MakeShape(S64, {});
+ *program_shape.add_parameters() = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}),
+ ShapeUtil::MakeShape(F32, {42, 42})});
+
+ *program_shape.mutable_result() = ShapeUtil::MakeShape(F32, {7});
+
+ program_shape.add_parameter_names("foo");
+ program_shape.add_parameter_names("bar");
+ program_shape.add_parameter_names("baz");
+ program_shape.add_parameter_names("qux qux");
+
+ // Create a copy of the program shape by round-tripping through a proto.
+ ProgramShape program_shape_copy(program_shape.ToProto());
+ ASSERT_EQ(program_shape.parameters_size(),
+ program_shape_copy.parameters_size());
+ for (int i = 0; i < program_shape.parameters_size(); ++i) {
+ EXPECT_TRUE(ShapeUtil::Equal(program_shape.parameters(i),
+ program_shape_copy.parameters(i)));
+ }
+
+ EXPECT_TRUE(
+ ShapeUtil::Equal(program_shape.result(), program_shape_copy.result()));
+
+ ASSERT_EQ(program_shape.parameter_names_size(),
+ program_shape_copy.parameter_names_size());
+ for (int i = 0; i < program_shape.parameter_names_size(); ++i) {
+ EXPECT_EQ(program_shape.parameter_names(i),
+ program_shape_copy.parameter_names(i));
+ }
+}
+
+TEST(ShapeTest, ProgramShapeToString) {
+ Shape opaque = ShapeUtil::MakeOpaqueShape();
+ Shape token = ShapeUtil::MakeTokenShape();
+ Shape scalar = ShapeUtil::MakeShape(F32, {});
+ Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
+ Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
+ Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
+ Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token});
+
+ ProgramShape prog = ShapeUtil::MakeProgramShape(
+ {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
+ EXPECT_EQ(
+ "((unknown): opaque[], "
+ "(unknown): f32[], "
+ "(unknown): u32[1,2], "
+ "(unknown): s32[3,4], "
+ "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
+ "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
+ "-> "
+ "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
+ ShapeUtil::HumanString(prog));
+
+ prog.add_parameter_names("arg0");
+ prog.add_parameter_names("scalar");
+ prog.add_parameter_names("matrix");
+ prog.add_parameter_names("matrix2");
+ prog.add_parameter_names("tuple");
+ prog.add_parameter_names("nested_tuple");
+ EXPECT_EQ(
+ "(arg0: opaque[], "
+ "scalar: f32[], "
+ "matrix: u32[1,2], "
+ "matrix2: s32[3,4], "
+ "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
+ "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
+ "token[])) "
+ "-> "
+ "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
+ ShapeUtil::HumanString(prog));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 7d011bf..b05ec20 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -563,6 +563,20 @@
HumanString(program_shape.result()));
}
+/* static */ string ShapeUtil::HumanString(
+ const ProgramShapeProto& program_shape_proto) {
+ std::vector<string> parameters;
+ for (auto& shape : program_shape_proto.parameters()) {
+ const int i = parameters.size();
+ parameters.push_back(StrCat(i < program_shape_proto.parameter_names_size()
+ ? program_shape_proto.parameter_names(i)
+ : "(unknown)",
+ ": ", HumanString(shape)));
+ }
+ return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
+ HumanString(program_shape_proto.result()));
+}
+
namespace {
// Parses shapes with simple recursive descent structure -- consumes from the
// front of s and passes that view recursively as required.
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 7f72e57..3796c5b 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -28,6 +28,7 @@
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -239,6 +240,7 @@
//
// (param_name: f32[42x12], ...) -> f32[24x42]
static string HumanString(const ProgramShape& program_shape);
+ static string HumanString(const ProgramShapeProto& program_shape_proto);
// Parses a ShapeUtil::HumanString-format shape string back into a shape
// object.
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 11b4933..ce6330a 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -575,37 +575,6 @@
"((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
"token[])",
ShapeUtil::HumanStringWithLayout(nested_tuple));
-
- ProgramShape prog = ShapeUtil::MakeProgramShape(
- {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
- EXPECT_EQ(
- "((unknown): opaque[], "
- "(unknown): f32[], "
- "(unknown): u32[1,2], "
- "(unknown): s32[3,4], "
- "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
- "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
- "-> "
- "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
- ShapeUtil::HumanString(prog));
-
- prog.add_parameter_names("arg0");
- prog.add_parameter_names("scalar");
- prog.add_parameter_names("matrix");
- prog.add_parameter_names("matrix2");
- prog.add_parameter_names("tuple");
- prog.add_parameter_names("nested_tuple");
- EXPECT_EQ(
- "(arg0: opaque[], "
- "scalar: f32[], "
- "matrix: u32[1,2], "
- "matrix2: s32[3,4], "
- "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
- "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
- "token[])) "
- "-> "
- "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
- ShapeUtil::HumanString(prog));
}
TEST(ShapeUtilTest, ForEachSubshapeArray) {
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 20493a3..2c18e2f 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -79,6 +79,7 @@
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
@@ -1291,6 +1292,7 @@
"enable_for_xla_interpreter",
],
deps = [
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1891,6 +1893,7 @@
xla_test(
name = "multioutput_fusion_test",
srcs = ["multioutput_fusion_test.cc"],
+ backends = ["gpu"],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index a5e9cfd..7e81905 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -1282,7 +1282,7 @@
}
template <typename T>
-class Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid : public ConvolutionTest {
+class Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid : public ConvolutionTest {
public:
void RunTest() {
XlaBuilder builder(TestName());
@@ -1341,8 +1341,72 @@
}
};
-TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, TestTypes);
-TYPED_TEST(Convolve2D_1x2x2x6_2x2x1x12_Grouped_Valid, Types) {
+TYPED_TEST_CASE(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, TestTypes);
+TYPED_TEST(Convolve2D_1x2x2x6_2x2x2x12_Grouped_Valid, Types) {
+ this->RunTest();
+}
+
+template <typename T>
+class Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid : public ConvolutionTest {
+ public:
+ void RunTest() {
+ XlaBuilder builder(TestName());
+ std::vector<int64> input_dims = {1, 2, 2, 1024};
+ std::vector<int64> filter_dims = {2, 2, 128, 512};
+ Shape input_shape = ShapeUtil::MakeShapeWithType<T>(input_dims);
+ Shape filter_shape = ShapeUtil::MakeShapeWithType<T>(filter_dims);
+ {
+ auto input = Parameter(&builder, 0, input_shape, "input");
+ auto filter = Parameter(&builder, 1, filter_shape, "filter");
+
+ // Tensorflow dimension numbers for 2D convolution.
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_input_batch_dimension(0);
+ dnums.set_output_batch_dimension(0);
+ dnums.add_input_spatial_dimensions(1);
+ dnums.add_output_spatial_dimensions(1);
+ dnums.add_input_spatial_dimensions(2);
+ dnums.add_output_spatial_dimensions(2);
+ dnums.set_input_feature_dimension(3);
+ dnums.set_output_feature_dimension(3);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+ ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
+ /*feature_group_count=*/8);
+ }
+
+ std::vector<T> input_elems(ShapeUtil::ElementsIn(input_shape),
+ static_cast<T>(1));
+
+ auto input_r1 = LiteralUtil::CreateR1<T>(input_elems);
+ auto input_r4 = input_r1.Reshape(input_dims).ConsumeValueOrDie();
+
+ std::vector<T> filter_elems(ShapeUtil::ElementsIn(filter_shape),
+ static_cast<T>(2));
+
+ auto filter_r1 = LiteralUtil::CreateR1<T>(filter_elems);
+ auto filter_r4 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie();
+
+ std::vector<T> output_elems(512, static_cast<T>(1024));
+ auto expected_r1 = LiteralUtil::CreateR1<T>(output_elems);
+ auto expected_r4 = expected_r1.Reshape({1, 1, 1, 512}).ConsumeValueOrDie();
+
+ auto input_literal =
+ client_->TransferToServer(input_r4).ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(filter_r4).ConsumeValueOrDie();
+
+ ComputeAndCompareLiteral(&builder, expected_r4,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+ }
+};
+
+TYPED_TEST_CASE(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, TestTypes);
+TYPED_TEST(Convolve2D_1x2x2x1024_2x2x128x512_Grouped_Valid, Types) {
this->RunTest();
}
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 6c0847a..25091b8 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -637,6 +637,76 @@
{x_data.get(), y_data.get()}, this->error_spec_);
}
+#ifndef XLA_TEST_BACKEND_CPU
+// TODO(b/74459949): failed on CPU on 2018-10-29.
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR3LhsR2Rhs) {
+ using T = TypeParam;
+
+ XlaBuilder builder(this->TestName());
+ auto x =
+ Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2}), "y");
+
+ DotDimensionNumbers dnums;
+ dnums.add_lhs_contracting_dimensions(1);
+ dnums.add_rhs_contracting_dimensions(1);
+ dnums.add_lhs_batch_dimensions(0);
+ dnums.add_rhs_batch_dimensions(0);
+
+ DotGeneral(x, y, dnums);
+
+ auto x_data =
+ this->client_
+ ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
+ .ConsumeValueOrDie();
+
+ auto y_data = this->client_
+ ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
+ {{1.0f, 0.0f}, {0.0f, 1.0f}}))
+ .ConsumeValueOrDie();
+
+ this->template ComputeAndCompareR2<T>(
+ &builder,
+ /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()},
+ this->error_spec_);
+}
+
+// TODO(b/74459949): failed on CPU on 2018-10-29.
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulR2LhsR3Rhs) {
+ using T = TypeParam;
+
+ XlaBuilder builder(this->TestName());
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2}), "x");
+ auto y =
+ Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2}), "y");
+
+ DotDimensionNumbers dnums;
+ dnums.add_lhs_contracting_dimensions(1);
+ dnums.add_rhs_contracting_dimensions(1);
+ dnums.add_lhs_batch_dimensions(0);
+ dnums.add_rhs_batch_dimensions(0);
+
+ DotGeneral(x, y, dnums);
+
+ auto x_data = this->client_
+ ->TransferToServer(LiteralUtil::CreateR2FromArray2D<T>(
+ {{1.0f, 0.0f}, {0.0f, 1.0f}}))
+ .ConsumeValueOrDie();
+
+ auto y_data =
+ this->client_
+ ->TransferToServer(LiteralUtil::CreateR3FromArray3D<T>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
+ .ConsumeValueOrDie();
+
+ this->template ComputeAndCompareR2<T>(
+ &builder,
+ /*expected=*/{{1.0f, 2.0f}, {7.0f, 8.0f}}, {x_data.get(), y_data.get()},
+ this->error_spec_);
+}
+#endif // XLA_TEST_BACKEND_CPU
+
XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
using T = TypeParam;
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
index 5cf87e5..34c7dc7 100644
--- a/tensorflow/compiler/xla/tests/replay_test.cc
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -55,7 +55,8 @@
client_->GetComputationShape(computation).ConsumeValueOrDie();
std::unique_ptr<ProgramShape> replayed_shape =
client_->GetComputationShape(replayed).ConsumeValueOrDie();
- ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
+ ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
+ replayed_shape->ToProto()));
// Run it.
Literal literal =
@@ -87,7 +88,8 @@
client_->GetComputationShape(computation).ConsumeValueOrDie();
std::unique_ptr<ProgramShape> replayed_shape =
client_->GetComputationShape(replayed).ConsumeValueOrDie();
- ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
+ ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
+ replayed_shape->ToProto()));
// Run it.
std::unique_ptr<GlobalData> x_data =
@@ -133,7 +135,8 @@
client_->GetComputationShape(computation).ConsumeValueOrDie();
std::unique_ptr<ProgramShape> replayed_shape =
client_->GetComputationShape(replayed).ConsumeValueOrDie();
- ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
+ ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
+ replayed_shape->ToProto()));
// Run it.
Literal literal =
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 2f18036..8b4be66 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -15,6 +15,7 @@
#include <cmath>
+#include "absl/base/casts.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
@@ -28,65 +29,112 @@
namespace {
template <typename FloatT, typename GeneratorT>
-void PopulateWithRandomFloatingPointDataImpl(Literal* literal,
- std::minstd_rand0* engine,
- bool no_duplicates) {
- CHECK(engine != nullptr);
- CHECK_EQ(literal->shape().element_type(),
- primitive_util::NativeToPrimitiveType<FloatT>());
- if (no_duplicates) {
- // Duplicates may be generated if the number of elements in the literal
- // exceeds the number of positive values supported by the type.
- FloatT next_value = std::numeric_limits<FloatT>::min();
- for (FloatT& value : literal->data<FloatT>()) {
- value = next_value;
- next_value =
- std::nextafter(next_value, std::numeric_limits<FloatT>::max());
- }
- std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
- *engine);
- } else {
- std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
- for (FloatT& value : literal->data<FloatT>()) {
- value = static_cast<FloatT>(generator(*engine));
- }
+void PopulateWithRandomFloatingPointData(Literal* literal,
+ std::minstd_rand0* engine) {
+ std::uniform_real_distribution<GeneratorT> generator(-0.1f, 0.2f);
+ for (FloatT& value : literal->data<FloatT>()) {
+ value = static_cast<FloatT>(generator(*engine));
}
}
template <typename FloatT>
-void PopulateWithRandomFloatingPointData(Literal* literal,
- std::minstd_rand0* engine,
- bool no_duplicates) {
- CHECK(engine != nullptr);
- PopulateWithRandomFloatingPointDataImpl<FloatT, FloatT>(literal, engine,
- no_duplicates);
-}
+void PopulateWithIntNext(Literal* literal);
template <>
-void PopulateWithRandomFloatingPointData<half>(Literal* literal,
- std::minstd_rand0* engine,
- bool no_duplicates) {
- // no_duplicates is ignored for half types. Unique values can only be
- // generated for arrays with fewer than ~2**16 elements and no_duplicates is
- // best-effort anyway.
- CHECK(engine != nullptr);
- std::uniform_real_distribution<float> generator(-0.1f, 0.2f);
+void PopulateWithIntNext<half>(Literal* literal) {
+ // Duplicates may be generated if we don't have enough bits.
+ uint16 next_value = 0;
for (half& value : literal->data<half>()) {
- value = static_cast<half>(generator(*engine));
+ // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
+ // the sign bit. We could be less wasteful, but this is best-effort anyway.
+ uint16 exponent_msb = next_value & 0x4000;
+ value.x = (next_value & 0xBFFF) | (exponent_msb << 1);
+ next_value++;
}
}
template <>
-void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal,
- std::minstd_rand0* engine,
- bool no_duplicates) {
- // no_duplicates is ignored for bfloat types. Unique values can only be
- // generated for arrays with fewer than ~2**16 elements and no_duplicates is
- // best-effort anyway.
- CHECK(engine != nullptr);
- std::uniform_real_distribution<float> generator(-0.1f, 0.2f);
+void PopulateWithIntNext<bfloat16>(Literal* literal) {
+ // Duplicates may be generated if we don't have enough bits.
+ uint16 next_value = 0;
for (bfloat16& value : literal->data<bfloat16>()) {
- value = static_cast<bfloat16>(generator(*engine));
+ // Zero-out the MSB of the exponent to avoid Infs and NaNs, and put it into
+ // the sign bit. We could be less wasteful, but this is best-effort anyway.
+ uint16 exponent_msb = next_value & 0x4000;
+ value.value = (next_value & 0xBFFF) | (exponent_msb << 1);
+ next_value++;
+ }
+}
+
+template <typename FloatT>
+void PopulateWithNextAfter(Literal* literal) {
+ // Duplicates may be generated if the number of elements in the literal
+ // exceeds the number of positive values supported by the type.
+ float next_value = std::numeric_limits<float>::min();
+ for (float& value : literal->data<float>()) {
+ value = next_value;
+ next_value = std::nextafter(next_value, std::numeric_limits<float>::max());
+ }
+}
+
+template <typename FloatT,
+ typename std::enable_if<std::is_same<bfloat16, FloatT>::value ||
+ std::is_same<half, FloatT>::value,
+ int>::type = 0>
+void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
+ PopulateWithIntNext<FloatT>(literal);
+ std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
+ *engine);
+}
+
+template <typename FloatT,
+ typename std::enable_if<!std::is_same<bfloat16, FloatT>::value &&
+ !std::is_same<half, FloatT>::value,
+ int>::type = 0>
+void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) {
+ PopulateWithNextAfter<FloatT>(literal);
+ std::shuffle(literal->data<FloatT>().begin(), literal->data<FloatT>().end(),
+ *engine);
+}
+
+template <typename FloatT>
+void PopulateWithFloatingPointData(Literal* literal, std::minstd_rand0* engine,
+ bool no_duplicates) {
+ CHECK(engine != nullptr);
+ CHECK_EQ(literal->shape().element_type(),
+ primitive_util::NativeToPrimitiveType<FloatT>());
+ if (no_duplicates) {
+ PopulateWithNoDuplicateData<FloatT>(literal, engine);
+ } else {
+ PopulateWithRandomFloatingPointData<FloatT, FloatT>(literal, engine);
+ }
+}
+
+template <>
+void PopulateWithFloatingPointData<half>(Literal* literal,
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
+ CHECK(engine != nullptr);
+ CHECK_EQ(literal->shape().element_type(),
+ primitive_util::NativeToPrimitiveType<half>());
+ if (no_duplicates) {
+ PopulateWithNoDuplicateData<half>(literal, engine);
+ } else {
+ PopulateWithRandomFloatingPointData<half, float>(literal, engine);
+ }
+}
+
+template <>
+void PopulateWithFloatingPointData<bfloat16>(Literal* literal,
+ std::minstd_rand0* engine,
+ bool no_duplicates) {
+ CHECK(engine != nullptr);
+ CHECK_EQ(literal->shape().element_type(),
+ primitive_util::NativeToPrimitiveType<bfloat16>());
+ if (no_duplicates) {
+ PopulateWithNoDuplicateData<bfloat16>(literal, engine);
+ } else {
+ PopulateWithRandomFloatingPointData<bfloat16, float>(literal, engine);
}
}
@@ -135,20 +183,16 @@
Literal literal(shape);
switch (shape.element_type()) {
case BF16:
- PopulateWithRandomFloatingPointData<bfloat16>(&literal, engine,
- no_duplicates);
+ PopulateWithFloatingPointData<bfloat16>(&literal, engine, no_duplicates);
break;
case F16:
- PopulateWithRandomFloatingPointData<half>(&literal, engine,
- no_duplicates);
+ PopulateWithFloatingPointData<half>(&literal, engine, no_duplicates);
break;
case F32:
- PopulateWithRandomFloatingPointData<float>(&literal, engine,
- no_duplicates);
+ PopulateWithFloatingPointData<float>(&literal, engine, no_duplicates);
break;
case F64:
- PopulateWithRandomFloatingPointData<double>(&literal, engine,
- no_duplicates);
+ PopulateWithFloatingPointData<double>(&literal, engine, no_duplicates);
break;
case S8:
PopulateWithRandomIntegralData<int8>(&literal, engine, no_duplicates);
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index e066b3f..e8f5d7a 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -175,5 +175,28 @@
}
}
+XLA_TEST_F(TestUtilsTest, NoDuplicatesBfloat16) {
+ // Inputs which are sort keys in key/value sorts should have no duplicates.
+ auto module = ParseHloString(R"(
+HloModule sort, is_scheduled=true
+
+ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) {
+ %parameter.0 = bf16[2,1452]{1,0} parameter(0)
+ %parameter.1 = s32[2,1452]{1,0} parameter(1)
+ ROOT %sort = (bf16[2,1452]{1,0}, s32[2,1452]{1,0}) sort(bf16[2,1452]{1,0} %parameter.0, s32[2,1452]{1,0} %parameter.1), dimensions={1}
+}
+)")
+ .ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> args,
+ MakeFakeArguments(module.get()));
+ ASSERT_EQ(args.size(), 2);
+ const Literal& key_arg = args[0];
+
+ absl::flat_hash_set<uint16> key_set;
+ for (const bfloat16& value : key_arg.data<bfloat16>()) {
+ EXPECT_TRUE(key_set.insert(absl::bit_cast<uint16>(value)).second);
+ }
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
index a2b7c26..601c6b0 100644
--- a/tensorflow/compiler/xla/tests/token_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -16,6 +16,7 @@
#include <array>
#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -108,26 +109,6 @@
::testing::HasSubstr("Entry parameter 0 is or contains a token shape"));
}
-XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) {
- std::unique_ptr<HloModule> module = CreateNewUnverifiedModule();
- auto builder = HloComputation::Builder(TestName());
- auto param = builder.AddInstruction(
- HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
- builder.AddInstruction(HloInstruction::CreateAfterAll({param}));
- builder.AddInstruction(
- HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(123)));
- module->AddEntryComputation(builder.Build());
-
- Status status =
- HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false)
- .Run(module.get())
- .status();
- ASSERT_IS_NOT_OK(status);
- EXPECT_THAT(status.error_message(),
- ::testing::HasSubstr(
- "Operands of token instructions must be TOKEN types"));
-}
-
XLA_TEST_F(TokenHloTest, TokenInWhileLoop) {
// Thread a token around a while loop. Token is created and consumed by a
// AfterAll instruction in the while body.
@@ -220,5 +201,95 @@
}
}
+XLA_TEST_F(TokenHloTest, AddDependency) {
+ string module_string = R"(
+HloModule AddDependency, is_scheduled=true
+
+// Computes (p0 + 42) * (-p1)
+// where there is a dependency from the add to the negation using a token
+// with after-all and add-dependency instructions.
+ENTRY %AddDependency (p0: f32[], p1: f32[]) -> f32[] {
+ %p0 = f32[] parameter(0)
+ %p1 = f32[] parameter(1)
+
+ %forty_two = f32[] constant(42.0)
+ %add = f32[] add(f32[] %p0, f32[] %forty_two)
+ %token = token[] after-all(f32[] %add)
+ %p1_after_token = f32[] add-dependency(f32[] %p1, token[] %token)
+ %neg = f32[] negate(f32[] %p1_after_token)
+ ROOT %product = f32[] multiply(f32[] %add, f32[] %neg)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseHloString(module_string, GetModuleConfigForTest()));
+ auto p0 = LiteralUtil::CreateR0<float>(10.0);
+ auto p1 = LiteralUtil::CreateR0<float>(3.0);
+ auto expected = LiteralUtil::CreateR0<float>(-156.0);
+ EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0, &p1}));
+}
+
+XLA_TEST_F(TokenHloTest, AddDependencyOfConstant) {
+ string module_string = R"(
+HloModule AddDependencyOfConstant, is_scheduled=true
+
+ENTRY %AddDependency (p0: f32[]) -> f32[] {
+ %p0 = f32[] parameter(0)
+ %forty_two = f32[] constant(42.0)
+ %token = token[] after-all(f32[] %p0)
+ %forty_two_after_token = f32[] add-dependency(f32[] %forty_two, token[] %token)
+ ROOT %product = f32[] multiply(f32[] %p0, f32[] %forty_two_after_token)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseHloString(module_string, GetModuleConfigForTest()));
+ auto p0 = LiteralUtil::CreateR0<float>(10.0);
+ auto expected = LiteralUtil::CreateR0<float>(420.0);
+ EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0}));
+}
+
+XLA_TEST_F(TokenHloTest, AddDependencyAsRoot) {
+ string module_string = R"(
+HloModule AddDependencyAsRoot, is_scheduled=true
+ENTRY %AddDependency (p: f32[3]) -> f32[3] {
+ %p = f32[3] parameter(0)
+ %neg = f32[3] negate(f32[3] %p)
+ %token = token[] after-all()
+ ROOT %add_dep = f32[3] add-dependency(f32[3] %neg, token[] %token)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseHloString(module_string, GetModuleConfigForTest()));
+ auto input = LiteralUtil::CreateR1<float>({1.0, 3.0, 7.0});
+ auto expected = LiteralUtil::CreateR1<float>({-1.0, -3.0, -7.0});
+ EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&input}));
+}
+
+XLA_TEST_F(TokenHloTest, TupleShapedAddDependency) {
+ string module_string = R"(
+HloModule TupleShapedAddDependency, is_scheduled=true
+ENTRY %TupleShapedAddDependency (p0: f32[3], p1: f32[3]) -> f32[3] {
+ %p0 = f32[3] parameter(0)
+ %p1 = f32[3] parameter(1)
+ %forty_two = f32[] constant(42.0)
+ %token = token[] after-all()
+ %tuple = (f32[3], token[], f32[3], f32[]) tuple(f32[3] %p0, token[] %token, f32[3] %p1, f32[] %forty_two)
+ %add_dep = (f32[3], token[], f32[3], f32[]) add-dependency((f32[3], token[], f32[3], f32[]) %tuple, token[] %token)
+ %elem0 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=0
+ %elem2 = f32[3] get-tuple-element((f32[3], token[], f32[3], f32[]) %add_dep), index=2
+ ROOT %diff = f32[3] subtract(f32[3] %elem0, f32[3] %elem2)
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ ParseHloString(module_string, GetModuleConfigForTest()));
+ auto p0 = LiteralUtil::CreateR1<float>({3.0, 3.0, 47.0});
+ auto p1 = LiteralUtil::CreateR1<float>({1.0, -2.0, 2.0});
+ auto expected = LiteralUtil::CreateR1<float>({2.0, 5.0, 45.0});
+ EXPECT_EQ(expected, ExecuteNoHloPasses(std::move(module), {&p0, &p1}));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index ca036f1..e57d072 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -157,10 +157,12 @@
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
stream_ptr.get(), Literal::CreateFromShape(rhs_arg_shape), rhs_arg));
+ ExecutableBuildOptions build_options;
+ build_options.mutable_debug_options()->set_xla_hlo_profile(true);
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<LocalExecutable> local_executable,
client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape},
- ExecutableBuildOptions().set_hlo_profile(true)));
+ build_options));
Executable* executable = local_executable->executable();
HloExecutionProfile hlo_execution_profile(
@@ -208,7 +210,7 @@
string profile_output;
ExecuteAndFetchProfile(&profile_output, client, computation, lhs_shape,
rhs_shape);
-
+ VLOG(4) << "Profile Output:\n" << profile_output;
std::vector<string> profile_output_lines =
absl::StrSplit(profile_output, '\n');
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 683ccc4..27ef86a 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -183,7 +183,7 @@
// Shape of the parameters and output of a computation (like a traditional
// function signature).
-message ProgramShape {
+message ProgramShapeProto {
repeated Shape parameters = 1;
Shape result = 2;
repeated string parameter_names = 3;
diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD
index 2ff9791..2dae746 100644
--- a/tensorflow/compiler/xrt/BUILD
+++ b/tensorflow/compiler/xrt/BUILD
@@ -22,6 +22,7 @@
deps = [
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
],
)
@@ -32,20 +33,25 @@
"xrt_compilation_cache.cc",
"xrt_device.cc",
"xrt_state.cc",
+ "xrt_util.cc",
],
hdrs = [
"xrt_compilation_cache.h",
"xrt_device.h",
"xrt_state.h",
+ "xrt_util.h",
],
deps = [
"//tensorflow/compiler/jit:xla_device",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:device_memory_allocator",
diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
index dc62cf7..1603b45 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
@@ -33,6 +33,7 @@
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
+#include "tensorflow/compiler/xrt/xrt_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
@@ -117,6 +118,10 @@
build_options.set_device_ordinal(client->default_device_ordinal());
build_options.set_result_layout(config.program_shape().result());
build_options.set_device_allocator(device_ref.backend()->memory_allocator());
+ if (config.has_debug_options()) {
+ *build_options.mutable_debug_options() =
+ BuildXlaDebugOptions(config.debug_options());
+ }
VLOG(1) << "Building executable";
auto compile_result =
@@ -174,11 +179,12 @@
ctx->set_output(0, handle_output);
xla::LocalExecutable* executable = entry->get().get_executable();
- xla::ProgramShape program_shape = executable->executable()
- ->module()
- .config()
- .entry_computation_layout()
- .ComputeProgramShape();
+ xla::ProgramShapeProto program_shape = executable->executable()
+ ->module()
+ .config()
+ .entry_computation_layout()
+ .ComputeProgramShape()
+ .ToProto();
Tensor program_shape_output(DT_STRING, TensorShape({1}));
program_shape_output.vec<string>()(0) = program_shape.SerializeAsString();
ctx->set_output(1, program_shape_output);
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index 25464b5..7e73db9 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -411,7 +411,7 @@
auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
- xla::ProgramShape program_shape;
+ xla::ProgramShapeProto program_shape;
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
EXPECT_EQ(program_shape.parameters_size(), 2);
}
@@ -465,7 +465,7 @@
auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
- xla::ProgramShape program_shape;
+ xla::ProgramShapeProto program_shape;
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
EXPECT_EQ(program_shape.parameters_size(), 2);
}
@@ -510,7 +510,7 @@
TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(),
{c_handle.program_shape}, {release}, &outputs));
- xla::ProgramShape program_shape;
+ xla::ProgramShapeProto program_shape;
EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec<string>()(0)));
EXPECT_EQ(program_shape.parameters_size(), 1);
@@ -520,7 +520,7 @@
<< xla::ShapeUtil::HumanStringWithLayout(program_shape.result());
xla::ProgramShape xla_program_shape =
- XlaCompiledProgramShape(xla_computation, *shapes);
+ XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes));
EXPECT_TRUE(xla::LayoutUtil::Equal(
xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
@@ -739,7 +739,7 @@
auto expected = xla::LiteralUtil::CreateR0<int64>(15123899);
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
- xla::ProgramShape program_shape;
+ xla::ProgramShapeProto program_shape;
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
EXPECT_EQ(program_shape.parameters_size(), 2);
EXPECT_TRUE(
diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto
index 6ab77fb..e149f2f 100644
--- a/tensorflow/compiler/xrt/xrt.proto
+++ b/tensorflow/compiler/xrt/xrt.proto
@@ -3,6 +3,7 @@
package xrt;
import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
+import "tensorflow/compiler/xla/xla.proto";
import "tensorflow/compiler/xla/xla_data.proto";
import "tensorflow/compiler/xla/service/hlo.proto";
@@ -36,16 +37,18 @@
tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3;
// The arg/result shapes for the whole computation.
- xla.ProgramShape program_shape = 4;
+ xla.ProgramShapeProto program_shape = 4;
// The arg/result shapes for each core of a model-parallel
// computation. per_core_args_and_result_shapes is optional for a
// single-core computation.
- repeated xla.ProgramShape per_core_program_shape = 5;
+ repeated xla.ProgramShapeProto per_core_program_shape = 5;
// Describes how replicated computation instances should be assigned to
// devices. There are num_cores_per_replica computations, and each one will be
// sent and executed to the set of replica device numbers described in the
// DeviceAssignment proto.
DeviceAssignment device_assignment = 6;
+ // The debugging options to be passed to the XLA compilation process.
+ xla.DebugOptions debug_options = 7;
}
// Options and XLA computation for a compilation.
diff --git a/tensorflow/compiler/xrt/xrt_util.cc b/tensorflow/compiler/xrt/xrt_util.cc
new file mode 100644
index 0000000..3ef8bed
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_util.cc
@@ -0,0 +1,76 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xrt/xrt_util.h"
+
+#include <stdlib.h>
+#include <string.h>
+
+#include "tensorflow/compiler/xla/debug_options_flags.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace {
+
+bool DebugOptionsPassThroughEnabled() {
+ const char* env = getenv("TF_XLA_DEBUG_OPTIONS_PASSTHROUGH");
+ bool enabled =
+ env != nullptr && (strcmp(env, "1") == 0 || strcmp(env, "true") == 0);
+ if (enabled) {
+ LOG(WARNING) << "Passing through XLA debug options!";
+ } else {
+ LOG(WARNING) << "TF_XLA_DEBUG_OPTIONS_PASSTHROUGH not set, not all options "
+ "will be retained";
+ }
+ return enabled;
+}
+
+string SafeDebugPath(const string& path) {
+ if (path.empty() || path.compare(0, 5, "gs://") == 0 ||
+ path.compare(0, 11, "bigstore://") == 0) {
+ return path;
+ }
+ LOG(WARNING) << "Invalid config path (will be dropped): " << path;
+ return string();
+}
+
+} // namespace
+
+xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options) {
+ static const bool options_passthrough = DebugOptionsPassThroughEnabled();
+ if (options_passthrough) {
+ return ref_options;
+ }
+ xla::DebugOptions options = xla::GetDebugOptionsFromFlags();
+ options.set_xla_generate_hlo_text_to(
+ SafeDebugPath(ref_options.xla_generate_hlo_text_to()));
+ options.set_xla_dump_optimized_hlo_proto_to(
+ SafeDebugPath(ref_options.xla_dump_optimized_hlo_proto_to()));
+ options.set_xla_dump_computations_to(
+ SafeDebugPath(ref_options.xla_dump_computations_to()));
+ options.set_xla_dump_executions_to(
+ SafeDebugPath(ref_options.xla_dump_executions_to()));
+ for (auto& pass : ref_options.xla_disable_hlo_passes()) {
+ options.add_xla_disable_hlo_passes(pass);
+ }
+ options.set_xla_dump_unoptimized_hlo_proto_to(
+ SafeDebugPath(ref_options.xla_dump_unoptimized_hlo_proto_to()));
+ options.set_xla_dump_per_pass_hlo_proto_to(
+ SafeDebugPath(ref_options.xla_dump_per_pass_hlo_proto_to()));
+ return options;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xrt/xrt_util.h b/tensorflow/compiler/xrt/xrt_util.h
new file mode 100644
index 0000000..d9c05a7
--- /dev/null
+++ b/tensorflow/compiler/xrt/xrt_util.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Utility functions in support of the XRT API.
+
+#ifndef TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_
+#define TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_
+
+#include "tensorflow/compiler/xla/xla.pb.h"
+
+namespace tensorflow {
+
+// Filters the debug options provided as argument according to the value of the
+// TF_XLA_DEBUG_OPTIONS_PASSTHROUGH environment variable. If such variable is
+// set to "1" or "true", the debug options will be returned as is. Otherwise
+// only a subset of them will be set in the returned ones, and all the paths
+// contained in it, will be limited to gs:// and bigstore:// ones.
+xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XRT_XRT_UTIL_H_
diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
index 13215ff..8b6ed9f 100644
--- a/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
+++ b/tensorflow/contrib/bayesflow/python/kernel_tests/monte_carlo_test.py
@@ -81,7 +81,7 @@
# Compute E_p[X_1 * X_2 > 0], with X_i the ith component of X ~ p(x).
# Should equal 1/2 because p is a spherical Gaussian centered at (0, 0).
def indicator(x):
- x1_times_x2 = math_ops.reduce_prod(x, reduction_indices=[-1])
+ x1_times_x2 = math_ops.reduce_prod(x, axis=[-1])
return 0.5 * (math_ops.sign(x1_times_x2) + 1.0)
prob = mc.expectation_importance_sampler(
diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
index 18d40fc..e83a548 100644
--- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
+++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py
@@ -353,12 +353,12 @@
def _sample_mean(values):
"""Mean over sample indices. In this module this is always [0]."""
- return math_ops.reduce_mean(values, reduction_indices=[0])
+ return math_ops.reduce_mean(values, axis=[0])
def _sample_max(values):
"""Max over sample indices. In this module this is always [0]."""
- return math_ops.reduce_max(values, reduction_indices=[0])
+ return math_ops.reduce_max(values, axis=[0])
def _get_samples(dist, z, n, seed):
diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc
index f083ce6..e95dc57 100644
--- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc
+++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.cc
@@ -366,6 +366,39 @@
return MakeUnique<MutateRowsResponse>(request.entries_size());
}
+std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
+ google::bigtable::v2::MutateRowResponse>>
+BigtableTestClient::AsyncMutateRow(
+ grpc::ClientContext* context,
+ google::bigtable::v2::MutateRowRequest const& request,
+ grpc::CompletionQueue* cq) {
+ LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
+ << "(); this will likely cause a crash!";
+ return nullptr;
+}
+
+std::unique_ptr<::grpc::ClientAsyncReaderInterface<
+ ::google::bigtable::v2::SampleRowKeysResponse>>
+BigtableTestClient::AsyncSampleRowKeys(
+ ::grpc::ClientContext* context,
+ const ::google::bigtable::v2::SampleRowKeysRequest& request,
+ ::grpc::CompletionQueue* cq, void* tag) {
+ LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
+ << "(); this will likely cause a crash!";
+ return nullptr;
+}
+
+std::unique_ptr<::grpc::ClientAsyncReaderInterface<
+ ::google::bigtable::v2::MutateRowsResponse>>
+BigtableTestClient::AsyncMutateRows(
+ ::grpc::ClientContext* context,
+ const ::google::bigtable::v2::MutateRowsRequest& request,
+ ::grpc::CompletionQueue* cq, void* tag) {
+ LOG(WARNING) << "Call to InMemoryDataClient::" << __func__
+ << "(); this will likely cause a crash!";
+ return nullptr;
+}
+
std::shared_ptr<grpc::Channel> BigtableTestClient::Channel() {
LOG(WARNING) << "Call to InMemoryDataClient::Channel(); this will likely "
"cause a crash!";
diff --git a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h
index dac2b16..c4a1f06 100644
--- a/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h
+++ b/tensorflow/contrib/bigtable/kernels/test_kernels/bigtable_test_client.h
@@ -61,6 +61,25 @@
MutateRows(grpc::ClientContext* context,
google::bigtable::v2::MutateRowsRequest const& request) override;
+ std::unique_ptr<grpc::ClientAsyncResponseReaderInterface<
+ google::bigtable::v2::MutateRowResponse>>
+ AsyncMutateRow(grpc::ClientContext* context,
+ google::bigtable::v2::MutateRowRequest const& request,
+ grpc::CompletionQueue* cq) override;
+
+ std::unique_ptr<::grpc::ClientAsyncReaderInterface<
+ ::google::bigtable::v2::SampleRowKeysResponse>>
+ AsyncSampleRowKeys(
+ ::grpc::ClientContext* context,
+ const ::google::bigtable::v2::SampleRowKeysRequest& request,
+ ::grpc::CompletionQueue* cq, void* tag) override;
+
+ std::unique_ptr<::grpc::ClientAsyncReaderInterface<
+ ::google::bigtable::v2::MutateRowsResponse>>
+ AsyncMutateRows(::grpc::ClientContext* context,
+ const ::google::bigtable::v2::MutateRowsRequest& request,
+ ::grpc::CompletionQueue* cq, void* tag) override;
+
std::shared_ptr<grpc::Channel> Channel() override;
private:
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 99ecded..a178820 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -428,6 +428,7 @@
learner_config,
examples_per_layer,
quantiles,
+ label_dimension=1,
num_trees=None,
feature_columns=None,
weight_column_name=None,
@@ -448,6 +449,10 @@
layer. It can also be a function that computes the number of examples
based on the depth of the layer that's being built.
quantiles: a list of quantiles for the loss, each between 0 and 1.
+ label_dimension: Dimension of regression label. This is the size
+ of the last dimension of the labels `Tensor` (typically, this has shape
+ `[batch_size, label_dimension]`). When label_dimension>1, it is
+ recommended to use multiclass strategy diagonal hessian or full hessian.
num_trees: An int, number of trees to build.
feature_columns: A list of feature columns.
weight_column_name: Name of the column for weights, or None if not
@@ -489,9 +494,11 @@
loss_fn=functools.partial(
losses.per_example_quantile_regression_loss, quantile=quantile),
link_fn=array_ops.identity,
- logit_dimension=1)
+ logit_dimension=label_dimension)
return head
+ learner_config.num_classes = max(2, label_dimension)
+
super(GradientBoostedDecisionTreeQuantileRegressor, self).__init__(
model_fn=model.model_builder,
params={
@@ -548,6 +555,7 @@
# Core..QuantileRegressor directly,
def core_quantile_regression_head(
quantiles,
+ label_dimension=1,
weight_column=None,
loss_reduction=core_losses.Reduction.SUM_OVER_NONZERO_WEIGHTS):
"""Core head for quantile regression problems."""
@@ -562,7 +570,7 @@
# pylint:disable=protected-access
head_fn = core_head_lib._regression_head(
- label_dimension=1,
+ label_dimension=label_dimension,
loss_fn=loss_fn,
loss_reduction=loss_reduction,
weight_column=weight_column)
@@ -747,6 +755,7 @@
learner_config,
examples_per_layer,
quantiles,
+ label_dimension=1,
num_trees=None,
feature_columns=None,
weight_column_name=None,
@@ -766,6 +775,10 @@
layer. It can also be a function that computes the number of examples
based on the depth of the layer that's being built.
quantiles: a list of quantiles for the loss, each between 0 and 1.
+ label_dimension: Dimension of regression label. This is the size
+ of the last dimension of the labels `Tensor` (typically, this has shape
+ `[batch_size, label_dimension]`). When label_dimension>1, it is
+ recommended to use multiclass strategy diagonal hessian or full hessian.
num_trees: An int, number of trees to build.
feature_columns: A list of feature columns.
weight_column_name: Name of the column for weights, or None if not
@@ -799,18 +812,31 @@
mode=mode,
config=config,
params={
- 'head': core_quantile_regression_head(quantiles[0]),
- 'feature_columns': feature_columns,
- 'learner_config': learner_config,
- 'num_trees': num_trees,
- 'weight_column_name': weight_column_name,
- 'examples_per_layer': examples_per_layer,
- 'center_bias': center_bias,
- 'logits_modifier_function': logits_modifier_function,
- 'use_core_libs': True,
- 'output_leaf_index': output_leaf_index,
- 'override_global_step_value': None,
- 'num_quantiles': num_quantiles,
+ 'head':
+ core_quantile_regression_head(
+ quantiles[0], label_dimension=label_dimension),
+ 'feature_columns':
+ feature_columns,
+ 'learner_config':
+ learner_config,
+ 'num_trees':
+ num_trees,
+ 'weight_column_name':
+ weight_column_name,
+ 'examples_per_layer':
+ examples_per_layer,
+ 'center_bias':
+ center_bias,
+ 'logits_modifier_function':
+ logits_modifier_function,
+ 'use_core_libs':
+ True,
+ 'output_leaf_index':
+ output_leaf_index,
+ 'override_global_step_value':
+ None,
+ 'num_quantiles':
+ num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 7863b5a..ee052ac 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -81,7 +81,7 @@
_QUANTILE_REGRESSION_SIZE = 1000
-def _quantile_regression_input_fns():
+def _quantile_regression_input_fns(two_dimension=False):
# The data generation is taken from
# http://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_quantile.html
np.random.seed(1)
@@ -90,20 +90,28 @@
"""The function to predict."""
return x * np.sin(x)
+ def g(x):
+ """The function to predict."""
+ return x * np.cos(x)
+
# Training data.
x = np.atleast_2d(np.random.uniform(0, 10.0,
size=_QUANTILE_REGRESSION_SIZE)).T
x = x.astype(np.float32)
# Labels.
- y = f(x).ravel()
+ if not two_dimension:
+ y = f(x).ravel()
+ else:
+ y = np.column_stack((f(x).ravel(), g(x).ravel()))
# Add random noise.
dy = 1.5 + 1.0 * np.random.random(y.shape)
noise = np.random.normal(0, dy)
y += noise
y_original = y.astype(np.float32)
- y = y.reshape(_QUANTILE_REGRESSION_SIZE, 1)
+ if not two_dimension:
+ y = y.reshape(_QUANTILE_REGRESSION_SIZE, 1)
train_input_fn = numpy_io.numpy_input_fn(
x=x,
@@ -439,6 +447,78 @@
self.assertTrue(frac_above_lower >= 0.92)
self.assertTrue(frac_above_lower <= 0.98)
+ # Multi-dimensional quantile regression.
+ def testQuantileRegressionMultiDimLabel(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 3
+ learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE
+ learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE
+ learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE
+ learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE
+ learner_config.regularization.tree_complexity = (
+ 1.0 / _QUANTILE_REGRESSION_SIZE)
+
+ train_input_fn, test_input_fn, y = _quantile_regression_input_fns(
+ two_dimension=True)
+
+ # 95% percentile.
+ model_upper = estimator.GradientBoostedDecisionTreeQuantileRegressor(
+ quantiles=[0.95],
+ learner_config=learner_config,
+ label_dimension=2,
+ num_trees=100,
+ examples_per_layer=_QUANTILE_REGRESSION_SIZE,
+ center_bias=False)
+
+ model_upper.fit(input_fn=train_input_fn, steps=1000)
+ result_iter = model_upper.predict(input_fn=test_input_fn)
+ upper = []
+ for prediction_dict in result_iter:
+ upper.append(prediction_dict["scores"])
+
+ count_below_upper = np.count_nonzero(upper > y, axis=0)
+ count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1))
+ frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3)
+ frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3)
+ frac_both_below_upper = round(1. * count_both_below_upper / len(y), 3)
+ # +/- 3%
+ self.assertTrue(frac_below_upper_0 >= 0.92)
+ self.assertTrue(frac_below_upper_0 <= 0.98)
+ self.assertTrue(frac_below_upper_1 >= 0.92)
+ self.assertTrue(frac_below_upper_1 <= 0.98)
+ self.assertTrue(frac_both_below_upper >= 0.92)
+ self.assertTrue(frac_both_below_upper <= 0.98)
+
+ train_input_fn, test_input_fn, _ = _quantile_regression_input_fns(
+ two_dimension=True)
+ model_lower = estimator.GradientBoostedDecisionTreeQuantileRegressor(
+ quantiles=[0.05],
+ learner_config=learner_config,
+ label_dimension=2,
+ num_trees=100,
+ examples_per_layer=_QUANTILE_REGRESSION_SIZE,
+ center_bias=False)
+
+ model_lower.fit(input_fn=train_input_fn, steps=1000)
+ result_iter = model_lower.predict(input_fn=test_input_fn)
+ lower = []
+ for prediction_dict in result_iter:
+ lower.append(prediction_dict["scores"])
+
+ count_above_lower = np.count_nonzero(lower < y, axis=0)
+ count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1))
+ frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3)
+ frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3)
+ frac_both_above_lower = round(1. * count_both_aboce_lower / len(y), 3)
+ # +/- 3%
+ self.assertTrue(frac_above_lower_0 >= 0.92)
+ self.assertTrue(frac_above_lower_0 <= 0.98)
+ self.assertTrue(frac_above_lower_1 >= 0.92)
+ self.assertTrue(frac_above_lower_1 <= 0.98)
+ self.assertTrue(frac_both_above_lower >= 0.92)
+ self.assertTrue(frac_both_above_lower <= 0.98)
+
class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase):
@@ -685,6 +765,79 @@
self.assertTrue(frac_above_lower >= 0.92)
self.assertTrue(frac_above_lower <= 0.98)
+ # Multi-dimensional quantile regression.
+ def testQuantileRegressionMultiDimLabel(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 3
+ learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE
+ learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE
+ learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE
+ learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE
+ learner_config.regularization.tree_complexity = (
+ 1.0 / _QUANTILE_REGRESSION_SIZE)
+
+ train_input_fn, test_input_fn, y = _quantile_regression_input_fns(
+ two_dimension=True)
+ y = y.reshape(_QUANTILE_REGRESSION_SIZE, 2)
+
+ # 95% percentile.
+ model_upper = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor(
+ quantiles=[0.95],
+ learner_config=learner_config,
+ num_trees=100,
+ label_dimension=2,
+ examples_per_layer=_QUANTILE_REGRESSION_SIZE,
+ center_bias=False)
+
+ model_upper.train(input_fn=train_input_fn, steps=1000)
+ result_iter = model_upper.predict(input_fn=test_input_fn)
+ upper = []
+ for prediction_dict in result_iter:
+ upper.append(prediction_dict["predictions"])
+
+ count_below_upper = np.count_nonzero(upper > y, axis=0)
+ count_both_below_upper = np.count_nonzero(np.prod(upper > y, axis=1))
+ frac_below_upper_0 = round(1. * count_below_upper[0] / len(y), 3)
+ frac_below_upper_1 = round(1. * count_below_upper[1] / len(y), 3)
+ frac_both_below_upper = round(1. * count_both_below_upper / len(y), 3)
+ # +/- 3%
+ self.assertTrue(frac_below_upper_0 >= 0.92)
+ self.assertTrue(frac_below_upper_0 <= 0.98)
+ self.assertTrue(frac_below_upper_1 >= 0.92)
+ self.assertTrue(frac_below_upper_1 <= 0.98)
+ self.assertTrue(frac_both_below_upper >= 0.92)
+ self.assertTrue(frac_both_below_upper <= 0.98)
+
+ train_input_fn, test_input_fn, _ = _quantile_regression_input_fns(
+ two_dimension=True)
+ model_lower = estimator.CoreGradientBoostedDecisionTreeQuantileRegressor(
+ quantiles=[0.05],
+ learner_config=learner_config,
+ num_trees=100,
+ label_dimension=2,
+ examples_per_layer=_QUANTILE_REGRESSION_SIZE,
+ center_bias=False)
+
+ model_lower.train(input_fn=train_input_fn, steps=1000)
+ result_iter = model_lower.predict(input_fn=test_input_fn)
+ lower = []
+ for prediction_dict in result_iter:
+ lower.append(prediction_dict["predictions"])
+
+ count_above_lower = np.count_nonzero(lower < y, axis=0)
+ count_both_aboce_lower = np.count_nonzero(np.prod(lower < y, axis=1))
+ frac_above_lower_0 = round(1. * count_above_lower[0] / len(y), 3)
+ frac_above_lower_1 = round(1. * count_above_lower[1] / len(y), 3)
+ frac_both_above_lower = round(1. * count_both_aboce_lower / len(y), 3)
+ # +/- 3%
+ self.assertTrue(frac_above_lower_0 >= 0.92)
+ self.assertTrue(frac_above_lower_0 <= 0.98)
+ self.assertTrue(frac_above_lower_1 >= 0.92)
+ self.assertTrue(frac_above_lower_1 <= 0.98)
+ self.assertTrue(frac_both_above_lower >= 0.92)
+ self.assertTrue(frac_both_above_lower <= 0.98)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/examples/boston.py b/tensorflow/contrib/boosted_trees/examples/boston.py
index 54c4ff0..09b240a 100644
--- a/tensorflow/contrib/boosted_trees/examples/boston.py
+++ b/tensorflow/contrib/boosted_trees/examples/boston.py
@@ -90,13 +90,13 @@
(x_train, y_train), (x_test,
y_test) = tf.keras.datasets.boston_housing.load_data()
- train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train,
batch_size=FLAGS.batch_size,
num_epochs=None,
shuffle=True)
- eval_input_fn = tf.estimator.inputs.numpy_input_fn(
+ eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False)
feature_columns = [
diff --git a/tensorflow/contrib/boosted_trees/examples/boston_combined.py b/tensorflow/contrib/boosted_trees/examples/boston_combined.py
index e04b56a..d640af3 100644
--- a/tensorflow/contrib/boosted_trees/examples/boston_combined.py
+++ b/tensorflow/contrib/boosted_trees/examples/boston_combined.py
@@ -80,13 +80,13 @@
(x_train, y_train), (x_test,
y_test) = tf.keras.datasets.boston_housing.load_data()
- train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={"x": x_train},
y=y_train,
batch_size=FLAGS.batch_size,
num_epochs=None,
shuffle=True)
- eval_input_fn = tf.estimator.inputs.numpy_input_fn(
+ eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False)
feature_columns = [
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
index 4da2529..d26af58 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
@@ -119,7 +119,7 @@
def not_active_inputs():
return (constant_op.constant([], dtype=dtypes.int32),
- constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]),
+ constant_op.constant_v1([], dtype=dtypes.int64, shape=[1, 2]),
empty_gradients, empty_hessians)
def active_inputs():
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index a2f7080..386dc19 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -36,9 +36,9 @@
empty_hess_shape = [1] + hessian_shape.as_list()
empty_grad_shape = [1] + gradient_shape.as_list()
- empty_gradients = constant_op.constant(
+ empty_gradients = constant_op.constant_v1(
[], dtype=dtypes.float32, shape=empty_grad_shape)
- empty_hessians = constant_op.constant(
+ empty_hessians = constant_op.constant_v1(
[], dtype=dtypes.float32, shape=empty_hess_shape)
return empty_gradients, empty_hessians
@@ -486,8 +486,8 @@
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = [0, 0, 0, 1]
- indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2])
- values = array_ops.constant([], dtype=dtypes.int64)
+ indices = constant_op.constant_v1([], dtype=dtypes.int64, shape=[0, 2])
+ values = constant_op.constant_v1([], dtype=dtypes.int64)
gradient_shape = tensor_shape.scalar()
hessian_shape = tensor_shape.scalar()
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
index 1fffbb5..0476bed 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
@@ -605,7 +605,7 @@
quantile_buckets, example_partition_ids, gradients,
hessians, weights, empty_gradients, empty_hessians):
"""Updates the state for dense split handler."""
- empty_float = constant_op.constant([], dtype=dtypes.float32)
+ empty_float = constant_op.constant_v1([], dtype=dtypes.float32)
quantile_values, quantile_weights = control_flow_ops.cond(
is_active[1], # For the next layer, this handler is inactive.
@@ -621,8 +621,8 @@
return (example_partition_ids, quantized_feature, gradients, hessians)
def not_ready_inputs_fn():
- return (constant_op.constant([], dtype=dtypes.int32),
- constant_op.constant([[]], dtype=dtypes.int64, shape=[1, 2]),
+ return (constant_op.constant_v1([], dtype=dtypes.int32),
+ constant_op.constant_v1([[]], dtype=dtypes.int64, shape=[1, 2]),
empty_gradients, empty_hessians)
example_partition_ids, feature_ids, gradients, hessians = (
@@ -708,11 +708,11 @@
def quantiles_not_ready():
"""The subgraph for when the quantiles are not ready."""
- return (constant_op.constant([], dtype=dtypes.int32),
- constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]),
+ return (constant_op.constant_v1([], dtype=dtypes.int32),
+ constant_op.constant_v1([], dtype=dtypes.int64, shape=[1, 2]),
empty_gradients, empty_hessians)
- empty_float = constant_op.constant([], dtype=dtypes.float32)
+ empty_float = constant_op.constant_v1([], dtype=dtypes.float32)
handler_not_active = (constant_op.constant(
[], dtype=dtypes.int64, shape=[0, 2]), empty_float,
constant_op.constant([0, 1], dtype=dtypes.int64),
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index 74b0ea6..4a1b528 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -39,9 +39,9 @@
empty_hess_shape = [1] + hessian_shape.as_list()
empty_grad_shape = [1] + gradient_shape.as_list()
- empty_gradients = constant_op.constant(
+ empty_gradients = constant_op.constant_v1(
[], dtype=dtypes.float32, shape=empty_grad_shape)
- empty_hessians = constant_op.constant(
+ empty_hessians = constant_op.constant_v1(
[], dtype=dtypes.float32, shape=empty_hess_shape)
return empty_gradients, empty_hessians
@@ -1476,9 +1476,9 @@
def testEmpty(self):
with self.cached_session() as sess:
- indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2])
+ indices = constant_op.constant_v1([], dtype=dtypes.int64, shape=[0, 2])
# No values in this feature column in this mini-batch.
- values = array_ops.constant([], dtype=dtypes.float32)
+ values = constant_op.constant_v1([], dtype=dtypes.float32)
sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1])
gradient_shape = tensor_shape.scalar()
@@ -1549,8 +1549,9 @@
sparse_column = array_ops.sparse_placeholder(dtypes.float32)
# We have two batches - at first, a sparse feature is empty.
- empty_indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2])
- empty_values = array_ops.constant([], dtype=dtypes.float32)
+ empty_indices = constant_op.constant_v1([], dtype=dtypes.int64,
+ shape=[0, 2])
+ empty_values = constant_op.constant_v1([], dtype=dtypes.float32)
empty_sparse_column = sparse_tensor.SparseTensor(empty_indices,
empty_values, [4, 2])
empty_sparse_column = empty_sparse_column.eval(session=sess)
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 85020c5..9fdc2fc 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -897,9 +897,9 @@
empty_hess_shape = [1] + self._hessian_shape.as_list()
empty_grad_shape = [1] + self._gradient_shape.as_list()
- empty_gradients = constant_op.constant(
+ empty_gradients = constant_op.constant_v1(
[], dtype=dtypes.float32, shape=empty_grad_shape)
- empty_hessians = constant_op.constant(
+ empty_hessians = constant_op.constant_v1(
[], dtype=dtypes.float32, shape=empty_hess_shape)
active_handlers = array_ops.unstack(active_handlers, axis=0)
diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py
index f8da20a..220e981 100644
--- a/tensorflow/contrib/boosted_trees/python/utils/losses.py
+++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py
@@ -65,9 +65,9 @@
below is this loss but squared in the region where the loss value < 1.
Args:
- labels: Rank 2 (N, 1) tensor of per-example labels.
+ labels: Rank 2 (N, D) tensor of per-example labels.
weights: Rank 2 (N, 1) tensor of per-example weights.
- predictions: Rank 2 (N, 1) tensor of per-example predictions.
+ predictions: Rank 2 (N, D) tensor of per-example predictions.
quantile: The quantile to use.
Returns:
@@ -119,8 +119,7 @@
labels = array_ops.expand_dims(labels, 1)
# Labels are indices of classes, convert them to one hot encodings.
target_one_hot = array_ops.one_hot(indices=labels, depth=num_classes)
- labels = math_ops.reduce_sum(
- input_tensor=target_one_hot, reduction_indices=[1])
+ labels = math_ops.reduce_sum(input_tensor=target_one_hot, axis=[1])
labels = math_ops.to_float(labels)
# Calculate softmax probabilities for each class.
diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py
index 242c1e8..5418e26 100644
--- a/tensorflow/contrib/checkpoint/python/containers.py
+++ b/tensorflow/contrib/checkpoint/python/containers.py
@@ -46,6 +46,10 @@
self._maybe_initialize_checkpointable()
self._name_counts = {}
+ @property
+ def _values(self):
+ return [dep.ref for dep in self._checkpoint_dependencies]
+
def track(self, checkpointable, base_name):
"""Add a dependency on `checkpointable`.
diff --git a/tensorflow/contrib/cluster_resolver/__init__.py b/tensorflow/contrib/cluster_resolver/__init__.py
index fd1263fe..ab0746a 100644
--- a/tensorflow/contrib/cluster_resolver/__init__.py
+++ b/tensorflow/contrib/cluster_resolver/__init__.py
@@ -24,7 +24,9 @@
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import SimpleClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import UnionClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver
+from tensorflow.contrib.cluster_resolver.python.training.kubernetes_cluster_resolver import KubernetesClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.slurm_cluster_resolver import SlurmClusterResolver
+from tensorflow.contrib.cluster_resolver.python.training.tfconfig_cluster_resolver import TFConfigClusterResolver
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
# pylint: enable=wildcard-import,unused-import
@@ -35,6 +37,8 @@
'SimpleClusterResolver',
'UnionClusterResolver',
'GceClusterResolver',
+ 'KubernetesClusterResolver',
+ 'TFConfigClusterResolver',
'TPUClusterResolver',
'SlurmClusterResolver',
]
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index 1630f01..e456643 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -58,6 +58,7 @@
srcs_version = "PY2AND3",
deps = [
"//tensorflow/compiler/jit:xla_ops_py",
+ "//tensorflow/compiler/jit/ops:xla_ops_grad",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
index 335ac79..f867cd1 100644
--- a/tensorflow/contrib/compiler/xla.py
+++ b/tensorflow/contrib/compiler/xla.py
@@ -23,6 +23,7 @@
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.jit.ops import xla_ops
+from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
index 41258ed..6926c0d 100644
--- a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
+++ b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
@@ -74,8 +74,8 @@
if (constraints_shape.ndims is None or
proxy_constraints_shape.ndims is None or
- any([ii is None for ii in constraints_shape.as_list()]) or
- any([ii is None for ii in proxy_constraints_shape.as_list()])):
+ any(ii is None for ii in constraints_shape.as_list()) or
+ any(ii is None for ii in proxy_constraints_shape.as_list())):
raise ValueError(
"constraints and proxy_constraints must have fully-known shapes")
if constraints_shape != proxy_constraints_shape:
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
index 1e2c912..a268415 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
@@ -778,8 +778,7 @@
# Test opaque_params size lower bound
opaque_params_size_v = sess.run(opaque_params_size)
- min_params_size = (
- np.sum([x.size for x in ws]) + np.sum([x.size for x in bs]))
+ min_params_size = sum(x.size for x in ws) + np.sum(x.size for x in bs)
logging.info("min_parm_size: %d vs actual_opaque_param_size: %d",
min_params_size, opaque_params_size_v)
self.assertLessEqual(min_params_size, opaque_params_size_v)
@@ -853,8 +852,7 @@
# Test opaque_params size lower bound
opaque_params_size_v = sess.run(opaque_params_size)
- min_params_size = (
- np.sum([x.size for x in ws]) + np.sum([x.size for x in bs]))
+ min_params_size = sum(x.size for x in ws) + sum(x.size for x in bs)
logging.info("min_parm_size: %d vs actual_opaque_param_size: %d",
min_params_size, opaque_params_size_v)
self.assertLessEqual(min_params_size, opaque_params_size_v)
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index 6cc93dc..7e1b406 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -1045,8 +1045,8 @@
# Min param size estimate = sum(weights.size) + sum(biases.size)
min_params_size = (
- np.sum(list(map(np.prod, rnn.canonical_weight_shapes))) +
- np.sum([sp[0] for sp in rnn.canonical_bias_shapes]))
+ sum(map(np.prod, rnn.canonical_weight_shapes)) +
+ sum(sp[0] for sp in rnn.canonical_bias_shapes))
opaque_params = rnn.trainable_variables[0]
with self.test_session(use_gpu=True, graph=ops.get_default_graph()):
diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
index 8bbcc7c..8e25637 100644
--- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
+++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py
@@ -21,6 +21,7 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras.engine import input_spec
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@@ -322,7 +323,7 @@
raise ValueError("The last dimension of the inputs to `CudnnRNN` "
"should be defined. Found `None`.")
self._input_size = input_shape[-1].value
- self.input_spec = base_layer.InputSpec(ndim=3, axes={-1: self._input_size})
+ self.input_spec = input_spec.InputSpec(ndim=3, axes={-1: self._input_size})
self._set_scope(None)
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index a938f86..8a8dc15 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -134,7 +134,7 @@
return tf.estimator.EstimatorSpec(mode, loss=loss)
if mode == tf.estimator.ModeKeys.TRAIN:
- train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss_fn())
+ train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```
@@ -248,19 +248,17 @@
workers doing synchronous all-reduce training. In the following code snippet, we
start multi-worker training using `tf.estimator.train_and_evaluate`:
-
```python
def model_main():
- estimator = ...
distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(
num_gpus_per_worker=2)
config = tf.estimator.RunConfig(train_distribute=distribution)
+ estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
train_spec = tf.estimator.TrainSpec(input_fn=input_fn)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
```
-
**Note**: You don't have to set "TF\_CONFIG" manually if you use our provided
Kubernetes template.
@@ -327,13 +325,13 @@
On your laptop, you can run
```python
-estimator = ...
distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(
num_gpus_per_worker=2)
config = tf.estimator.RunConfig(
experimental_distribute=tf.contrib.distribute.DistributeConfig(
train_distribute=distribution,
remote_cluster={"worker": ["host1:port", "host2:port", "host3:port"]}))
+estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
train_spec = tf.estimator.TrainSpec(input_fn=input_fn)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 2a595e7..38ce0b2 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -27,13 +27,13 @@
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
- "//tensorflow/python:device_util",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/distribute:device_util",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
@@ -49,28 +49,9 @@
srcs = ["mirrored_strategy.py"],
visibility = ["//tensorflow:internal"],
deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:device",
- "//tensorflow/python:device_util",
- "//tensorflow/python:distribute",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:pywrap_tensorflow",
- "//tensorflow/python:tensor_util",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/distribute:cross_device_ops",
- "//tensorflow/python/distribute:multi_worker_util",
- "//tensorflow/python/distribute:reduce_util",
- "//tensorflow/python/distribute:shared_variable_creator",
+ "//tensorflow/python/distribute:distribute_lib",
+ "//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/distribute:values",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/eager:tape",
],
)
@@ -133,10 +114,10 @@
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:array_ops",
- "//tensorflow/python:distribute",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:reduce_util",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
@@ -175,11 +156,11 @@
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
- "//tensorflow/python:distribute",
"//tensorflow/python:framework_ops",
"//tensorflow/python:layers",
"//tensorflow/python:training",
"//tensorflow/python:variables",
+ "//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
@@ -200,10 +181,10 @@
":tpu_strategy",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
"//tensorflow/contrib/optimizer_v2:training",
- "//tensorflow/python:distribute",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/eager:context",
"@absl_py//absl/testing:parameterized",
],
@@ -248,11 +229,11 @@
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
- "//tensorflow/python:distribute",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:layers",
"//tensorflow/python:state_ops",
"//tensorflow/python:variable_scope",
+ "//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:values",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
@@ -435,6 +416,7 @@
"multi_and_single_gpu",
"no_oss", # http://b/119349471
"no_pip",
+ "tf_integration_test",
],
)
@@ -448,6 +430,7 @@
"multi_and_single_gpu",
"no_oss", # http://b/119349471
"no_pip",
+ "tf_integration_test",
],
)
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index f13cf26..617a95f 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -18,10 +18,13 @@
from __future__ import division
from __future__ import print_function
+import copy
+
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
+from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
@@ -29,7 +32,6 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import distribute as distribute_lib
# TODO(yuefengz): support in-graph replication.
@@ -263,13 +265,15 @@
self._container_strategy(), self._num_gpus_per_worker, cluster_spec,
task_type, task_id)
- if not session_config:
- return
+ if session_config:
+ session_config.CopyFrom(self._update_config_proto(session_config))
+ def _update_config_proto(self, config_proto):
+ updated_config = copy.deepcopy(config_proto)
# Enable the scoped allocator optimization for CollectiveOps. This
# optimization converts many small all-reduces into fewer larger
# all-reduces.
- rewrite_options = session_config.graph_options.rewrite_options
+ rewrite_options = updated_config.graph_options.rewrite_options
rewrite_options.scoped_allocator_optimization = (
rewriter_config_pb2.RewriterConfig.ON)
# We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op =
@@ -279,7 +283,7 @@
rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")
if not self._cluster_spec:
- return
+ return updated_config
assert self._task_type
assert self._task_id is not None
@@ -287,20 +291,22 @@
# Collective group leader is needed for collective ops to coordinate
# workers.
if "chief" in self._cluster_spec.jobs:
- session_config.experimental.collective_group_leader = (
+ updated_config.experimental.collective_group_leader = (
"/job:chief/replica:0/task:0")
else:
if "worker" not in self._cluster_spec.jobs:
raise ValueError(
"You must have `chief` or `worker` jobs in the `cluster_spec`.")
- session_config.experimental.collective_group_leader = (
+ updated_config.experimental.collective_group_leader = (
"/job:worker/replica:0/task:0")
# The device filters prevent communication between workers.
- del session_config.device_filters[:]
- session_config.device_filters.append(
+ del updated_config.device_filters[:]
+ updated_config.device_filters.append(
"/job:%s/task:%d" % (self._task_type, self._task_id))
+ return updated_config
+
@property
def experimental_between_graph(self):
return True
@@ -320,3 +326,8 @@
@property
def _num_replicas_in_sync(self):
return len(self._devices) * self._num_workers
+
+ # TODO(priyag): Delete this once all strategies use global batch size.
+ @property
+ def _global_batch_size(self):
+ return False
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index a47eef9..09239ff 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -26,6 +26,7 @@
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import cross_device_utils
@@ -56,9 +57,6 @@
collective_key_base = 0
def setUp(self):
- self._run_options = config_pb2.RunOptions()
- self._run_options.experimental.collective_graph_key = 6
-
# We use a different key_base for each test so that collective keys won't be
# reused.
# TODO(yuefengz, tucker): enable it to reuse collective keys in different
@@ -145,11 +143,10 @@
if context.num_gpus() < d.extended._num_gpus_per_worker:
return True
- sess.run(
- variables.global_variables_initializer(), options=self._run_options)
+ sess.run(variables.global_variables_initializer())
for i in range(10):
- b, a = sess.run((before_out, after_out), options=self._run_options)
+ b, a = sess.run((before_out, after_out))
if i == 0:
before, = b
after, = a
@@ -234,11 +231,9 @@
destinations='/cpu:0'))[0]
x = distribution.unwrap(x)[0]
- sess.run(
- variables.global_variables_initializer(), options=self._run_options)
+ sess.run(variables.global_variables_initializer())
- x_value, reduced_x_value = sess.run([x, reduced_x],
- options=self._run_options)
+ x_value, reduced_x_value = sess.run([x, reduced_x])
self.assertTrue(
np.allclose(x_value, reduced_x_value, atol=1e-5),
msg=('x_value = %r, reduced_x_value = %r' % (x_value,
@@ -249,7 +244,7 @@
expected_values):
distribution, master_target, config = self._get_test_object(
task_type, task_id, num_gpus)
- devices = distribution.worker_devices
+ devices = distribution.extended.worker_devices
with ops.Graph().as_default(), \
self.cached_session(config=config,
@@ -342,6 +337,32 @@
self._test_input_fn_iterator('worker', 1, num_gpus,
input_fn, expected_values)
+ def testUpdateConfigProto(self):
+ distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
+ num_gpus_per_worker=2)
+ distribution.configure(
+ cluster_spec=self._cluster_spec, task_type='worker', task_id=1)
+
+ config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden'])
+ rewrite_options = config_proto.graph_options.rewrite_options
+ rewrite_options.scoped_allocator_opts.enable_op.append('to_be_removed')
+
+ new_config = distribution.update_config_proto(config_proto)
+
+ # Verify group leader
+ self.assertEqual('/job:worker/replica:0/task:0',
+ new_config.experimental.collective_group_leader)
+
+ # Verify device filters.
+ self.assertEqual(['/job:worker/task:1'], new_config.device_filters)
+
+ # Verify rewrite options.
+ new_rewrite_options = new_config.graph_options.rewrite_options
+ self.assertEqual(rewriter_config_pb2.RewriterConfig.ON,
+ new_rewrite_options.scoped_allocator_optimization)
+ self.assertEqual(['CollectiveReduce'],
+ new_rewrite_options.scoped_allocator_opts.enable_op)
+
class DistributedCollectiveAllReduceStrategyTestWithChief(
CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
@@ -352,10 +373,6 @@
cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=0, has_chief=True)
- def setUp(self):
- super(DistributedCollectiveAllReduceStrategyTestWithChief, self).setUp()
- self._run_options.experimental.collective_graph_key = 7
-
@combinations.generate(
combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
def testMinimizeLossGraph(self, num_gpus):
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index f3ce547..c5ce29a 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -53,11 +53,11 @@
from tensorflow.contrib.optimizer_v2 import adagrad as adagrad_v2
from tensorflow.contrib.optimizer_v2 import adam as adam_v2
from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2
+from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.training import adagrad
from tensorflow.python.training import adam
-from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop
from tensorflow.python.util import tf_inspect
diff --git a/tensorflow/contrib/distribute/python/cross_device_ops_test.py b/tensorflow/contrib/distribute/python/cross_device_ops_test.py
index 00672a4..5d8690b 100644
--- a/tensorflow/contrib/distribute/python/cross_device_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_device_ops_test.py
@@ -29,6 +29,7 @@
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
+from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
@@ -37,7 +38,6 @@
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.training import device_util
def _make_per_replica(values, devices, regroup=False):
@@ -119,7 +119,7 @@
sess.run(list(left._index.values())), list(right._index.values()))
def _testReductionAndBroadcast(self, cross_device_ops, distribution):
- devices = distribution.worker_devices
+ devices = distribution.extended.worker_devices
values = [constant_op.constant(float(d)) for d in range(len(devices))]
per_replica = _make_per_replica(values, devices)
@@ -381,27 +381,31 @@
distribution=[
combinations.NamedDistribution(
"MirroredCPU",
- lambda: mirrored_strategy.MirroredStrategy(num_gpus=0),
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=0),
required_gpus=0),
combinations.NamedDistribution(
"Mirrored1GPU",
- lambda: mirrored_strategy.MirroredStrategy(num_gpus=1),
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=1),
required_gpus=1),
combinations.NamedDistribution(
"Mirrored2GPUs",
- lambda: mirrored_strategy.MirroredStrategy(num_gpus=2),
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2),
required_gpus=2),
+ # pylint: disable=g-long-lambda
combinations.NamedDistribution(
"CoreMirroredCPU",
- lambda: mirrored_strategy.CoreMirroredStrategy(num_gpus=0),
+ lambda: mirrored_strategy.CoreMirroredStrategy(
+ num_gpus_per_worker=0),
required_gpus=0),
combinations.NamedDistribution(
"CoreMirrored1GPU",
- lambda: mirrored_strategy.CoreMirroredStrategy(num_gpus=1),
+ lambda: mirrored_strategy.CoreMirroredStrategy(
+ num_gpus_per_worker=1),
required_gpus=1),
combinations.NamedDistribution(
"CoreMirrored2GPUs",
- lambda: mirrored_strategy.CoreMirroredStrategy(num_gpus=2),
+ lambda: mirrored_strategy.CoreMirroredStrategy(
+ num_gpus_per_worker=2),
required_gpus=2),
],
mode=["graph"])
diff --git a/tensorflow/contrib/distribute/python/cross_device_utils_test.py b/tensorflow/contrib/distribute/python/cross_device_utils_test.py
index 6086eba..2303a31 100644
--- a/tensorflow/contrib/distribute/python/cross_device_utils_test.py
+++ b/tensorflow/contrib/distribute/python/cross_device_utils_test.py
@@ -22,13 +22,13 @@
from tensorflow.contrib.distribute.python import combinations
from tensorflow.python.distribute import cross_device_utils
+from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
-from tensorflow.python.training import device_util
class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase):
diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py
index 264dca6..e170856 100644
--- a/tensorflow/contrib/distribute/python/estimator_integration_test.py
+++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py
@@ -77,12 +77,12 @@
train_input_fn = self.dataset_input_fn(
x={'x': data},
y=data,
- batch_size=batch_size // len(distribution.worker_devices),
+ batch_size=batch_size // distribution.num_replicas_in_sync,
shuffle=True)
eval_input_fn = self.dataset_input_fn(
x={'x': data},
y=data,
- batch_size=batch_size // len(distribution.worker_devices),
+ batch_size=batch_size // distribution.num_replicas_in_sync,
shuffle=False)
predict_input_fn = numpy_io.numpy_input_fn(
x={'x': data}, batch_size=batch_size, shuffle=False)
diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py
index 3e7d5df..0f35657 100644
--- a/tensorflow/contrib/distribute/python/estimator_training_test.py
+++ b/tensorflow/contrib/distribute/python/estimator_training_test.py
@@ -204,10 +204,10 @@
train_input_fn = self.dataset_input_fn(
x={"x": DATA},
y=DATA,
- batch_size=BATCH_SIZE // len(train_distribute.worker_devices),
+ batch_size=BATCH_SIZE // train_distribute.num_replicas_in_sync,
shuffle=True)
if eval_distribute:
- eval_batch_size = BATCH_SIZE // len(eval_distribute.worker_devices)
+ eval_batch_size = BATCH_SIZE // eval_distribute.num_replicas_in_sync
else:
eval_batch_size = BATCH_SIZE
eval_input_fn = self.dataset_input_fn(
@@ -522,7 +522,7 @@
run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
train_distribute=mirrored_strategy.CoreMirroredStrategy(
- num_gpus=2)))
+ num_gpus_per_worker=2)))
def test_should_run_distribute_coordinator(self):
"""Tests that should_run_distribute_coordinator return a correct value."""
@@ -546,11 +546,11 @@
config_with_train_distribute = run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
train_distribute=mirrored_strategy.CoreMirroredStrategy(
- num_gpus=2)))
+ num_gpus_per_worker=2)))
config_with_eval_distribute = run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
eval_distribute=mirrored_strategy.CoreMirroredStrategy(
- num_gpus=2)))
+ num_gpus_per_worker=2)))
self.assertTrue(
dc_training.should_run_distribute_coordinator(
config_with_train_distribute))
@@ -564,7 +564,7 @@
config = run_config_lib.RunConfig(
experimental_distribute=DistributeConfig(
train_distribute=mirrored_strategy.CoreMirroredStrategy(
- num_gpus=2)))
+ num_gpus_per_worker=2)))
self.assertFalse(dc_training.should_run_distribute_coordinator(config))
def test_init_run_config_duplicate_distribute(self):
diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py
index 0d7e11c..6dfd85b 100644
--- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py
@@ -28,6 +28,7 @@
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.estimator import run_config
from tensorflow.python.estimator import training
from tensorflow.python.estimator.canned import dnn_linear_combined
@@ -46,7 +47,6 @@
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
-from tensorflow.python.training import distribution_strategy_context as ds_context
class KerasOptimizerV2IntegrationTest(test.TestCase, parameterized.TestCase):
@@ -83,11 +83,11 @@
train_input_fn = self.dataset_input_fn(
x={'x': data},
y=data,
- batch_size=batch_size // len(distribution.worker_devices))
+ batch_size=batch_size // distribution.num_replicas_in_sync)
eval_input_fn = self.dataset_input_fn(
x={'x': data},
y=data,
- batch_size=batch_size // len(distribution.worker_devices))
+ batch_size=batch_size // distribution.num_replicas_in_sync)
predict_input_fn = numpy_io.numpy_input_fn(
x={'x': data}, batch_size=batch_size, shuffle=False)
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 29d85fe..07027bd 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -35,6 +35,7 @@
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import distributed_training_utils
+from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@@ -42,7 +43,6 @@
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop
-
_RANDOM_SEED = 1337
_TRAIN_SIZE = 200
_INPUT_SIZE = (10,)
@@ -973,6 +973,28 @@
ref_output = np.ones((160, 1), dtype=np.float32)
self.assertArrayNear(output, ref_output, 1e-1)
+ @combinations.generate(strategy_minus_tpu_combinations())
+ def testOptimizerWithCallbacks(self, distribution):
+ with self.cached_session():
+ model = get_model()
+
+ optimizer = gradient_descent_keras.SGD(0.01)
+ loss = 'mse'
+ model.compile(optimizer, loss, distribute=distribution)
+
+ dataset = get_dataset(distribution)
+
+ def schedule(_):
+ return 0.001
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
+ grouped_models = distribution.unwrap(model._grouped_model)
+ with distribution.scope():
+ for m in grouped_models:
+ self.assertAllClose(0.001, keras.backend.get_value(
+ m.optimizer.lr), atol=1e-05, rtol=1e-05)
+
class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
@@ -1090,14 +1112,14 @@
def schedule(_):
return 0.001
with self.assertRaisesRegexp(ValueError,
- 'LearningRateScheduler callback is not '
- 'supported with DistributionStrategy.'):
+ 'You must specify a Keras Optimizer V2 when '
+ 'using'):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
with self.assertRaisesRegexp(ValueError,
- 'ReduceLROnPlateau callback is not '
- 'supported with DistributionStrategy.'):
+ 'You must specify a Keras Optimizer V2 when '
+ 'using'):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
callbacks=[keras.callbacks.ReduceLROnPlateau()])
with self.assertRaisesRegexp(ValueError,
@@ -1247,7 +1269,7 @@
model.set_weights(initial_weights)
model.compile(
loss=keras.losses.mean_squared_error,
- optimizer=gradient_descent.GradientDescentOptimizer(0.5),
+ optimizer=gradient_descent_keras.SGD(0.5),
distribute=with_distribution)
training_inputs, eval_inputs, predict_inputs = (
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index e77d3d4..129b394 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -344,7 +344,7 @@
run_step()
v = all_vars[0]
- self.assertTrue(all([v is vi for vi in all_vars[1:]]))
+ self.assertTrue(all(v is vi for vi in all_vars[1:]))
weight = numpy.squeeze(self.evaluate(v))
# Our model is:
# predict = x * w
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index f743216..4a594f0 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -12,805 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Class MirroredStrategy implementing DistributionStrategy."""
+"""Contrib version of MirroredStrategy."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import contextlib
-from functools import partial
-import threading
+import functools
-from tensorflow.python import pywrap_tensorflow
-from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
-from tensorflow.python.distribute import multi_worker_util
-from tensorflow.python.distribute import reduce_util
-from tensorflow.python.distribute import shared_variable_creator
+from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import values
-from tensorflow.python.eager import context
-from tensorflow.python.eager import tape
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import device as tf_device
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import coordinator
-from tensorflow.python.training import device_util
-from tensorflow.python.training import distribute as distribute_lib
-from tensorflow.python.util import nest
-# TODO(josh11b): Replace asserts in this file with if ...: raise ...
-
-
-@contextlib.contextmanager
-def _enter_graph(g):
- if context.executing_eagerly():
- with g.as_default(), context.eager_mode():
- yield
- else:
- with g.as_default():
- yield
-
-
-def _cpu_device(device):
- cpu_device = tf_device.DeviceSpec.from_string(device)
- cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0))
- return cpu_device.to_string()
-
-
-class _RequestedStop(Exception):
- pass
-
-
-# _call_for_each_replica and _reduce_non_distributed_value are not members of
-# MirroredStrategy so that they are generally not allowed to use anything
-# specific to MirroredStrategy and thus can be shared with other distribution
-# strategies.
-
-
-# TODO(yuefengz): maybe create a common class for those who need to call this
-# _call_for_each_replica.
-def _call_for_each_replica(distribution, fn, args, kwargs):
- """Run `fn` in separate threads, once per replica/worker device.
-
- Args:
- distribution: the DistributionStrategy object.
- fn: function to run (will be run once per device, each in its own thread).
- args: positional arguments for `fn`
- kwargs: keyword arguments for `fn`.
-
- Returns:
- Merged return value of `fn` across all replicas.
-
- Raises:
- RuntimeError: If fn() calls get_replica_context().merge_call() a different
- number of times from the available devices.
- """
- # TODO(josh11b): Add this option once we add synchronization to variable
- # creation. Until then, this is pretty unsafe to use.
- run_concurrently = False
- if not context.executing_eagerly():
- # Needed for per-thread device, etc. contexts in graph mode.
- ops.get_default_graph().switch_to_thread_local()
-
- coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))
-
- shared_variable_store = {}
-
- # TODO(isaprykin): Create these threads once instead of during every run()
- # call.
- threads = []
- for index, d in enumerate(distribution.extended.worker_devices):
- variable_creator_fn = shared_variable_creator.make_fn(
- shared_variable_store, index)
- t = MirroredExtended._MirroredReplicaThread( # pylint: disable=protected-access
- distribution, coord, d, variable_creator_fn, fn,
- *values.select_device(d, args), **values.select_device(d, kwargs))
- threads.append(t)
-
- for t in threads:
- t.start()
-
- # When `fn` starts `should_run` event is set on _MirroredReplicaThread
- # (`MRT`) threads. The execution waits until
- # `MRT.has_paused` is set, which indicates that either `fn` is
- # complete or a `get_replica_context().merge_call()` is called. If `fn` is
- # complete, then `MRT.done` is set to True. Otherwise, arguments
- # of `get_replica_context().merge_call` from all paused threads are grouped
- # and the `merge_fn` is performed. Results of the
- # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
- # Each such `get_replica_context().merge_call` call returns the
- # `MRT.merge_result` for that thread when `MRT.should_run` event
- # is reset again. Execution of `fn` resumes.
-
- try:
- with coord.stop_on_exception():
- all_done = False
- while not all_done and not coord.should_stop():
- done = []
- if run_concurrently:
- for t in threads:
- t.should_run.set()
- for t in threads:
- t.has_paused.wait()
- t.has_paused.clear()
- if coord.should_stop():
- return None
- done.append(t.done)
- else:
- for t in threads:
- t.should_run.set()
- t.has_paused.wait()
- t.has_paused.clear()
- if coord.should_stop():
- return None
- done.append(t.done)
- if coord.should_stop():
- return None
- all_done = all(done)
- if not all_done:
- if any(done):
- raise RuntimeError("Some replicas made a different number of "
- "replica_context().merge_call() calls.")
- # get_replica_context().merge_call() case
- merge_args = values.regroup({t.device: t.merge_args for t in threads})
- merge_kwargs = values.regroup(
- {t.device: t.merge_kwargs for t in threads})
- # We capture the name_scope of the MRT when we call merge_fn
- # to ensure that if we have opened a name scope in the MRT,
- # it will be respected when executing the merge function. We only
- # capture the name_scope from the first MRT and assume it is
- # the same for all other MRTs.
- mtt_captured_name_scope = threads[0].captured_name_scope
- with ops.name_scope(mtt_captured_name_scope):
- merge_result = threads[0].merge_fn(distribution, *merge_args,
- **merge_kwargs)
- for t in threads:
- t.merge_result = values.select_device(t.device, merge_result)
- finally:
- for t in threads:
- t.should_run.set()
- coord.join(threads)
-
- return values.regroup({t.device: t.main_result for t in threads})
-
-
-def _reduce_non_distributed_value(extended, reduce_op, value, destinations):
- """Reduce a non-DistributedValue `value` to `destinations`."""
- if isinstance(value, values.DistributedValues):
- raise ValueError("You are passing a `DistributedValue` to "
- "`_reduce_non_distributed_value`, which is not allowed.")
-
- # If the same value is present on all replicas then the PerReplica value will
- # be a single value. We also handle the case when `value` is a single value
- # and equal to 0.
- if value == 0:
- return 0
- # If there is only a single value and the reduce op is MEAN,
- # that value should be on all destinations.
- if reduce_op == reduce_util.ReduceOp.MEAN:
- return value
-
- cross_device_ops_lib.validate_destinations(destinations)
- # We do not support a reduce op of SUM if the value is the same across
- # all replicas. We call this as part of assign functions for MirroredVariables
- # and summing up identical values across replicas is not clearly defined.
- if (len(extended.worker_devices) != 1 or
- not cross_device_ops_lib.check_destinations(destinations)):
- raise ValueError("A non-DistributedValues value %s cannot be reduced with "
- "the given reduce op %s." % (value, reduce_op))
- # TODO(anjalisridhar): Moves these methods to a device utility file?
- devices = cross_device_ops_lib.get_devices_from(destinations)
- if len(devices) == 1:
- with ops.device(devices[0]):
- return array_ops.identity(value)
- else:
- value_updates = {}
- for d in devices:
- with ops.device(d):
- value_updates[d] = array_ops.identity(value)
- return values.Mirrored(value_updates)
-
-
-def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring
- # Figure out what collections this variable should be added to.
- # We'll add the MirroredVariable to those collections instead.
- collections = kwargs.pop("collections", None)
- if collections is None:
- collections = [ops.GraphKeys.GLOBAL_VARIABLES]
- kwargs["collections"] = []
-
- # Get synchronization value
- synchronization = kwargs.get("synchronization",
- variable_scope.VariableSynchronization.ON_WRITE)
- if synchronization == variable_scope.VariableSynchronization.NONE:
- raise ValueError("`NONE` variable synchronization mode is not "
- "supported with `Mirrored` distribution strategy. Please"
- " change the `synchronization` for variable: " +
- kwargs["name"])
- elif synchronization == variable_scope.VariableSynchronization.ON_READ:
- # Variables that are to be synced on read are replica local.
- is_replica_local = True
- kwargs["trainable"] = False
- elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
- synchronization == variable_scope.VariableSynchronization.AUTO):
- # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
- is_replica_local = False
- else:
- raise ValueError("Invalid variable synchronization mode: " +
- synchronization + " for variable: " + kwargs["name"])
-
- # Get aggregation value
- aggregation = kwargs.pop("aggregation",
- variable_scope.VariableAggregation.NONE)
- if aggregation not in (
- variable_scope.VariableAggregation.NONE,
- variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN,
- variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
- ):
- raise ValueError("Invalid variable aggregation mode: " + aggregation +
- " for variable: " + kwargs["name"])
-
- # Ignore user-specified caching device, not needed for mirrored variables.
- kwargs.pop("caching_device", None)
-
- # TODO(josh11b,apassos): It would be better if variable initialization
- # was never recorded on the tape instead of having to do this manually
- # here.
- with tape.stop_recording():
- index = real_mirrored_creator(devices, *args, **kwargs)
-
- if is_replica_local:
- result = values.ReplicaLocalVariable(
- index, index[devices[0]], aggregation)
- else:
- result = values.MirroredVariable(index, index[devices[0]], aggregation)
-
- # Add the wrapped variable to the requested collections.
- # The handling of eager mode and the global step matches
- # ResourceVariable._init_from_args().
- if not context.executing_eagerly():
- g = ops.get_default_graph()
- # If "trainable" is True, next_creator() will add the member variables
- # to the TRAINABLE_VARIABLES collection, so we manually remove
- # them and replace with the MirroredVariable. We can't set
- # "trainable" to False for next_creator() since that causes functions
- # like implicit_gradients to skip those variables.
- if kwargs.get("trainable", True):
- collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
- l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
- for v in index.values():
- if v in l:
- l.remove(v)
- g.add_to_collections(collections, result)
- elif ops.GraphKeys.GLOBAL_STEP in collections:
- ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
-
- return result
-
-
-class CoreMirroredStrategy(distribute_lib.DistributionStrategy):
- """Mirrors vars to distribute across multiple devices and machines.
-
- *** core version ***
-
- This strategy uses one replica per device and sync replication for its
- multi-GPU version.
-
- When `cluster_spec` is given by the `configure` method., it turns into the
- mulit-worker version that works on multiple workers with in-graph replication.
- Note: `configure` will be called by higher-level APIs if running in
- distributed environment.
-
- There are several important concepts for distributed TensorFlow, e.g.
- `client`, `job`, 'task', `cluster`, `in-graph replication` and
- 'synchronous training' and they have already been defined in the
- [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed).
- The distribution strategy inherits these concepts as well and in addition to
- that we also clarify several more concepts:
-
- * **In-graph replication**: the `client` creates a single `tf.Graph` that
- specifies tasks for devices on all workers. The `client` then creates a
- client session which will talk to the `master` service of a `worker`. Then
- the `master` will partition the graph and distribute the work to all
- participating workers.
- * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
- physical machine. We will have multiple `worker`s with different `task`
- index. They all do similar things except for one worker checkpointing model
- variables, writing summaries, etc. in addition to its ordinary work.
-
- The multi-worker version of this class maps one replica to one device on a
- worker. It mirrors all model variables on all replicas. For example, if you
- have two `worker`s and each `worker` has 4 GPUs, it will create 8 copies of
- the model variables on these 8 GPUs. Then like in MirroredStrategy, each
- replica performs their computation with their own copy of variables unless in
- cross-replica model where variable or tensor reduction happens.
-
- Args:
- devices: a list of device strings.
- num_gpus: number of GPUs. For local training, either specify `devices` or
- `num_gpus`. In distributed training, this must be specified as number of
- GPUs on each worker.
- num_gpus_per_worker: number of GPUs per worker. This is the same as
- `num_gpus` and only one of `num_gpus` and `num_gpus_per_worker` can be
- specified.
- cross_device_ops: optional, a descedant of `CrossDeviceOps`. If this is not
- set, the `configure` method will try to find the best one.
- auto_shard_dataset: whether to auto-shard the dataset when there are
- multiple workers.
- """
-
- def __init__(self,
- devices=None,
- num_gpus=None,
- num_gpus_per_worker=None,
- cross_device_ops=None,
- auto_shard_dataset=False):
- extended = CoreMirroredExtended(
- self, devices, num_gpus, num_gpus_per_worker,
- cross_device_ops, auto_shard_dataset)
- super(CoreMirroredStrategy, self).__init__(extended)
-
-
-class CoreMirroredExtended(distribute_lib.DistributionStrategyExtended):
- """Implementation of CoreMirroredStrategy."""
-
- def __init__(self,
- container_strategy,
- devices=None,
- num_gpus=None,
- num_gpus_per_worker=None,
- cross_device_ops=None,
- auto_shard_dataset=False):
- super(CoreMirroredExtended, self).__init__(container_strategy)
- self._cross_device_ops = cross_device_ops
- self._auto_shard_dataset = auto_shard_dataset
- # Remember num GPUs which might be needed by `configure` method.
- if num_gpus is not None and num_gpus_per_worker is not None:
- raise ValueError(
- "You cannot specify both `num_gpus` and `num_gpus_per_worker`.")
- if num_gpus is not None:
- self._num_gpus = num_gpus
- else:
- self._num_gpus = num_gpus_per_worker
-
- self._initialize_local(self._num_gpus, devices)
-
- def _initialize_local(self, num_gpus, devices):
- """Initializes the object for local training."""
- self._cluster_spec = None
- # Convert `num_gpus` into `devices`, shouldn't specify both.
- if devices is None:
- if num_gpus is None:
- num_gpus = context.num_gpus()
- if num_gpus == 0:
- devices = ["/device:CPU:0"]
- else:
- devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
- elif num_gpus is not None:
- raise ValueError("Must only specify one of `devices` and `num_gpus`.")
- self._num_gpus = num_gpus
- # TODO(yuefengz): consider setting the default device.
-
- assert devices, "Must specify at least one device."
- assert len(set(devices)) == len(devices), (
- "No duplicates allowed in `devices` argument.")
- # TODO(josh11b): Require at least 2 devices?
- self._devices = [device_util.resolve(d) for d in devices]
- self._canonical_device_set = set(self._devices)
- self._device_index = values.PerReplica(
- {d: i for i, d in enumerate(devices)})
-
- def _initialize_multi_worker(self, num_gpus, cluster_spec):
- """Initializes the object for multi-worker training."""
- cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
- self._cluster_spec = cluster_spec
-
- self._workers = []
- for job in ["chief", "worker"]:
- for task in range(len(cluster_spec.as_dict().get(job, []))):
- self._workers.append("/job:%s/task:%d" % (job, task))
-
- if num_gpus is None:
- raise ValueError("`num_gpus` is required if `cluster_spec` is given.")
- if num_gpus > 0:
- self._worker_devices = [
- (worker, [
- device_util.canonicalize(worker + "/device:GPU:%d" % gpu)
- for gpu in range(num_gpus)
- ]) for worker in self._workers
- ]
- else:
- self._worker_devices = [
- (worker, [device_util.canonicalize(worker, "/device:CPU:0")])
- for worker in self._workers
- ]
-
- devices = nest.flatten([l for _, l in self._worker_devices])
-
- # Setting `_default_device` will add a device scope in the
- # distribution.scope. We set the default device to the first worker. When
- # users specify device under distribution.scope by
- # with tf.device("/cpu:0"):
- # ...
- # their ops will end up on the cpu device of its first worker, e.g.
- # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
- self._default_device = self._workers[0]
-
- assert devices, "Must specify at least one device."
- assert len(set(devices)) == len(devices), (
- "No duplicates allowed in `devices` argument.")
- # TODO(josh11b): Require at least 2 devices?
- self._devices = [device_util.resolve(d) for d in devices]
- self._canonical_device_set = set(self._devices)
- self._device_index = values.PerReplica(
- {d: i for i, d in enumerate(devices)})
-
- def _create_variable(self, next_creator, *args, **kwargs):
- """Create a mirrored variable. See `DistributionStrategy.scope`."""
- colocate_with = kwargs.pop("colocate_with", None)
- devices = self._get_devices_from(colocate_with)
-
- def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring
- index = {}
- for i, d in enumerate(devices):
- with ops.device(d):
- if i > 0:
- # Give replicas meaningful distinct names:
- var0name = index[devices[0]].name.split(":")[0]
- # We append a / to variable names created on replicas with id > 0 to
- # ensure that we ignore the name scope and instead use the given
- # name as the absolute name of the variable.
- kwargs["name"] = "%s/replica_%d/" % (var0name, i)
- # Initialize replicas with the same value:
- def initial_value_fn(device=d):
- if context.executing_eagerly():
- init_value = index[devices[0]].value()
- return array_ops.identity(init_value)
- else:
- with ops.device(device):
- init_value = index[devices[0]].initial_value
- return array_ops.identity(init_value)
- kwargs["initial_value"] = initial_value_fn
- with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- # Don't record operations (e.g. other variable reads) during
- # variable creation.
- with tape.stop_recording():
- v = next_creator(*args, **kwargs)
- assert not isinstance(v, values.DistributedVariable)
- index[d] = v
- return index
-
- return _create_mirrored_variable(devices, _real_mirrored_creator, *args,
- **kwargs)
-
- def _distribute_dataset(self, dataset_fn):
- if self._cluster_spec:
- return values.MultiWorkerDataset(
- partial(self._call_dataset_fn, dataset_fn), self._worker_devices,
- auto_shard=self._auto_shard_dataset)
- else:
- return values.PerReplicaDataset(
- self._call_dataset_fn(dataset_fn), self._devices)
-
- def _make_dataset_iterator(self, dataset):
- if self._cluster_spec:
- worker_device_pairs = self._worker_devices
- else:
- worker_device_pairs = [("/job:localhost", self._devices)]
- return values.DatasetIterator(dataset, worker_device_pairs,
- self._num_replicas_in_sync)
-
- def _make_input_fn_iterator(
- self,
- input_fn,
- replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
- input_contexts = []
- if self._cluster_spec:
- num_workers = len(self._worker_devices)
- worker_device_pairs = self._worker_devices
- else:
- num_workers = 1
- worker_device_pairs = [("/job:localhost", self._devices)]
- for i in range(num_workers):
- input_contexts.append(distribute_lib.InputContext(
- num_input_pipelines=num_workers,
- input_pipeline_id=i,
- num_replicas_in_sync=self._num_replicas_in_sync))
- return values.InputFunctionIterator(
- input_fn, worker_device_pairs, input_contexts)
-
- # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
- def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
- initial_loop_values=None):
- if initial_loop_values is None:
- initial_loop_values = {}
- initial_loop_values = nest.flatten(initial_loop_values)
-
- ctx = values.MultiStepContext()
- def body(i, *args):
- """A wrapper around `fn` to create the while loop body."""
- del args
- fn_inputs = iterator.get_next()
- if not isinstance(fn_inputs, tuple):
- fn_inputs = (fn_inputs,)
- fn_result = fn(ctx, fn_inputs)
- for (name, output) in ctx.last_step_outputs.items():
- # Convert all outputs to tensors, potentially from `DistributedValues`.
- ctx.last_step_outputs[name] = self._unwrap(output)
- flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
- with ops.control_dependencies([fn_result]):
- return [i + 1] + flat_last_step_outputs
-
- # We capture the control_flow_context at this point, before we run `fn`
- # inside a while_loop. This is useful in cases where we might need to exit
- # these contexts and get back to the outer context to do some things, for
- # e.g. create an op which should be evaluated only once at the end of the
- # loop on the host. One such usage is in creating metrics' value op.
- self._outer_control_flow_context = (
- ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
-
- cond = lambda i, *args: i < iterations
- i = constant_op.constant(0)
- loop_result = control_flow_ops.while_loop(
- cond, body, [i] + initial_loop_values, name="",
- parallel_iterations=1, back_prop=False, swap_memory=False,
- return_same_structure=True)
- del self._outer_control_flow_context
-
- ctx.run_op = control_flow_ops.group(loop_result)
-
- # Convert the last_step_outputs from a list to the original dict structure
- # of last_step_outputs.
- last_step_tensor_outputs = loop_result[1:]
- last_step_tensor_outputs_dict = nest.pack_sequence_as(
- ctx.last_step_outputs, last_step_tensor_outputs)
-
- for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access
- output = last_step_tensor_outputs_dict[name]
- # For outputs that have already been reduced, wrap them in a Mirrored
- # container, else in a PerReplica container.
- if reduce_op is None:
- last_step_tensor_outputs_dict[name] = values.regroup(
- {d: t for d, t in zip(self._devices, output)}, values.PerReplica)
- else:
- assert len(output) == 1
- last_step_tensor_outputs_dict[name] = output[0]
-
- ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access
- return ctx
-
- def _broadcast_to(self, tensor, destinations):
- # This is both a fast path for Python constants, and a way to delay
- # converting Python values to a tensor until we know what type it
- # should be converted to. Otherwise we have trouble with:
- # global_step.assign_add(1)
- # since the `1` gets broadcast as an int32 but global_step is int64.
- if isinstance(tensor, (float, int)):
- return tensor
- # TODO(josh11b): In eager mode, use one thread per device, or async mode.
- return self._get_cross_device_ops().broadcast(
- tensor, destinations or self._devices)
-
- def _call_for_each_replica(self, fn, args, kwargs):
- return _call_for_each_replica(self._container_strategy(), fn, args, kwargs)
-
- def _configure(self,
- session_config=None,
- cluster_spec=None,
- task_type=None,
- task_id=None):
- del task_type, task_id
-
- if session_config:
- session_config.isolate_session_state = True
-
- if cluster_spec:
- self._initialize_multi_worker(self._num_gpus, cluster_spec)
-
- if self._cross_device_ops is None:
- if self._cluster_spec:
- # It currently cannot detect the toplogy of remote workers. So we
- # hard-code the multi-worker all-reduce algorithm for now.
- if len(self._workers) == 1:
- # The default is "nccl".
- self._cross_device_ops = (
- cross_device_ops_lib.AllReduceCrossDeviceOps())
- else:
- # The default is hierarchical reduce and broadcast.
- self._cross_device_ops = cross_device_ops_lib.MultiWorkerAllReduce(
- self._workers, self._num_gpus)
- else:
- self._cross_device_ops = cross_device_ops_lib.choose_the_best(
- self._devices, session_config=session_config)
-
- def _get_cross_device_ops(self):
- if self._cross_device_ops is None:
- self._cross_device_ops = (
- cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps())
- return self._cross_device_ops
-
- def _reduce_to(self, reduce_op, value, destinations):
- assert not isinstance(value, values.Mirrored)
- if not isinstance(value, values.DistributedValues):
- # This function handles reducing values that are not PerReplica or
- # Mirrored values. For example, the same value could be present on all
- # replicas in which case `value` would be a single value or value could
- # be 0.
- return _reduce_non_distributed_value(self, reduce_op, value,
- destinations)
- return self._get_cross_device_ops().reduce(
- reduce_op, value, destinations=destinations)
-
- def _batch_reduce_to(self, reduce_op, value_destination_pairs):
- return self._get_cross_device_ops().batch_reduce(reduce_op,
- value_destination_pairs)
-
- def _update(self, var, fn, args, kwargs, group):
- # TODO(josh11b): In eager mode, use one thread per device.
- assert isinstance(var, values.DistributedVariable)
- updates = {}
- for d, v in var._index.items(): # pylint: disable=protected-access
- name = "update_%d" % self._device_index.get(d)
- with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
- # If args and kwargs are not mirrored, the value is returned as is.
- updates[d] = fn(v,
- *values.select_device_mirrored(d, args),
- **values.select_device_mirrored(d, kwargs))
- return values.update_regroup(self, updates, group)
-
- def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
- assert isinstance(colocate_with, list)
- # TODO(josh11b): In eager mode, use one thread per device.
- updates = {}
- for d in colocate_with:
- name = "update_%d" % self._device_index.get(d)
- with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
- updates[d] = fn(*values.select_device_mirrored(d, args),
- **values.select_device_mirrored(d, kwargs))
- return values.update_regroup(self, updates, group)
-
- def read_var(self, replica_local_var):
- """Read the aggregate value of a replica-local variable."""
- if isinstance(replica_local_var, values.ReplicaLocalVariable):
- return replica_local_var._get_cross_replica() # pylint: disable=protected-access
- assert isinstance(replica_local_var, values.Mirrored)
- return array_ops.identity(replica_local_var.get())
-
- def _unwrap(self, val):
- if isinstance(val, values.DistributedValues):
- # Return in a deterministic order.
- if set(val.devices) == self._canonical_device_set:
- return [val.get(device=d) for d in self._devices]
- return [val.get(device=d) for d in sorted(val.devices)]
- return [val]
-
- def value_container(self, val):
- return values.value_container(val)
-
- @property
- def _num_replicas_in_sync(self):
- return len(self._devices)
-
- @property
- def worker_devices(self):
- # Make a copy to prevent users from accidentally mutating our copy.
- return list(self._devices)
-
- @property
- def parameter_devices(self):
- return list(self._devices)
-
- @property
- def experimental_between_graph(self):
- return False
-
- @property
- def experimental_should_init(self):
- return True
-
- @property
- def should_checkpoint(self):
- return True
-
- @property
- def should_save_summary(self):
- return True
-
- def non_slot_devices(self, var_list):
- del var_list
- return list(self._devices)
-
- def _get_devices_from(self, colocate_with=None):
- if colocate_with is None:
- return self._devices
- else:
- return cross_device_ops_lib.get_devices_from(colocate_with)
-
- class _MirroredReplicaThread(threading.Thread):
- """A thread that runs() a function on a device."""
-
- def __init__(self, dist, coord, device, variable_creator_fn, fn, *args,
- **kwargs):
- super(CoreMirroredExtended._MirroredReplicaThread, self).__init__() # pylint: disable=protected-access
- self.coord = coord
- self.distribution = dist
- self.device = device
- self.replica_id = dist.worker_devices.index(device)
- self.variable_creator_fn = variable_creator_fn
- # State needed to run and return the results of `fn`.
- self.main_fn = fn
- self.main_args = args
- self.main_kwargs = kwargs
- self.main_result = None
- self.done = False
- # State needed to run the next merge_call() (if any) requested via
- # ReplicaContext.
- self.merge_fn = None
- self.merge_args = None
- self.merge_kwargs = None
- self.merge_result = None
- self.captured_name_scope = None
- # We use a thread.Event for the main thread to signal when this
- # thread should start running (`should_run`), and another for
- # this thread to transfer control back to the main thread
- # (`has_paused`, either when it gets to a
- # `get_replica_context().merge_call` or when `fn` returns). In
- # either case the event starts cleared, is signaled by calling
- # set(). The receiving thread waits for the signal by calling
- # wait() and then immediately clearing the event using clear().
- self.should_run = threading.Event()
- self.has_paused = threading.Event()
- # These fields have to do with inheriting various contexts from the
- # parent thread:
- # pylint: disable=protected-access
- self.context_mode = context.context()._eager_context.mode
- if not context.context()._context_handle:
- context.context()._initialize_handle_and_devices()
- self.context_device_policy = (
- pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
- context.context()._context_handle))
- self.graph = ops.get_default_graph()
- self._variable_creator_stack = self.graph._variable_creator_stack[:]
- self._captured_var_scope = variable_scope.get_variable_scope()
- # Adding a "/" at end lets us re-enter this scope later.
- self._name_scope = self.graph.get_name_scope()
- if self._name_scope:
- self._name_scope += "/"
- if self.replica_id > 0:
- if not self._name_scope:
- self._name_scope = ""
- self._name_scope += "replica_%d/" % self.replica_id
-
- def run(self):
- # pylint: disable=protected-access
- self.graph._variable_creator_stack = self._variable_creator_stack
- self.should_run.wait()
- self.should_run.clear()
- try:
- if self.coord.should_stop():
- return
- with self.coord.stop_on_exception(), \
- context.context()._mode(self.context_mode), \
- context.context().device_policy(self.context_device_policy), \
- _enter_graph(self.graph), \
- MirroredReplicaContext(self.distribution, constant_op.constant(
- self.replica_id, dtypes.int32)), \
- ops.device(self.device), \
- ops.name_scope(self._name_scope), \
- variable_scope.variable_scope(
- self._captured_var_scope, reuse=self.replica_id > 0), \
- variable_scope.variable_creator_scope(self.variable_creator_fn):
- self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
- self.done = True
- finally:
- self.has_paused.set()
+# pylint: disable=protected-access,invalid-name
+_call_for_each_replica = mirrored_strategy._call_for_each_replica
+_reduce_non_distributed_value = mirrored_strategy._reduce_non_distributed_value
+_create_mirrored_variable = mirrored_strategy._create_mirrored_variable
+CoreMirroredStrategy = mirrored_strategy.MirroredStrategy
+CoreMirroredExtended = mirrored_strategy.MirroredExtended
+# pylint: enable=protected-access,invalid-name
class MirroredStrategy(distribute_lib.DistributionStrategy):
@@ -873,26 +95,29 @@
auto_shard_dataset=False,
cross_tower_ops=None):
assert not (cross_device_ops and cross_tower_ops)
- extended = MirroredExtended(
- self, devices, num_gpus, num_gpus_per_worker,
- cross_device_ops or cross_tower_ops, auto_shard_dataset)
+ if num_gpus is not None and num_gpus_per_worker is not None:
+ raise ValueError(
+ "You cannot specify both `num_gpus` and `num_gpus_per_worker`.")
+ if num_gpus is None:
+ num_gpus = num_gpus_per_worker
+ extended = MirroredExtended(self, devices, num_gpus,
+ cross_device_ops or cross_tower_ops,
+ auto_shard_dataset)
super(MirroredStrategy, self).__init__(extended)
class MirroredExtended(CoreMirroredExtended):
"""Implementation of (contrib) MirroredStrategy."""
- # pylint: disable=useless-super-delegation
def __init__(self,
container_strategy,
devices=None,
- num_gpus=None,
num_gpus_per_worker=None,
cross_device_ops=None,
auto_shard_dataset=False):
super(MirroredExtended, self).__init__(
- container_strategy, devices, num_gpus, num_gpus_per_worker,
- cross_device_ops, auto_shard_dataset)
+ container_strategy, devices, num_gpus_per_worker, cross_device_ops)
+ self._auto_shard_dataset = auto_shard_dataset
def _make_dataset_iterator(self, dataset):
"""Make iterator from dataset without splitting the batch.
@@ -909,39 +134,21 @@
if self._cluster_spec:
worker_device_pairs = self._worker_devices
else:
- worker_device_pairs = [("/job:localhost", self._devices)]
+ worker = device_util.canonicalize("/device:CPU:0")
+ worker_device_pairs = [(worker, self._devices)]
return values.DatasetIterator(dataset, worker_device_pairs)
+ def _distribute_dataset(self, dataset_fn):
+ if self._cluster_spec:
+ return values.MultiWorkerDataset(
+ functools.partial(self._call_dataset_fn, dataset_fn),
+ self._worker_devices,
+ auto_shard=self._auto_shard_dataset)
+ else:
+ return values.PerReplicaDataset(
+ self._call_dataset_fn(dataset_fn), self._devices)
-class MirroredReplicaContext(distribute_lib.ReplicaContext):
- """ReplicaContext used in MirroredStrategy.call_for_each_replica().
-
- Opened in `_MirroredReplicaThread`, to allow the user to invoke
- `MirroredStrategy`'s specific implementation of `merge_call()`,
- which works by delegating the function and its arguments to
- the main thread (the one that invoked
- `MirroredStrategy.call_for_each_replica()`).
- """
-
- def _merge_call(self, fn, args, kwargs):
- """Delegate to the main thread to actually perform merge_call()."""
- t = threading.current_thread() # a _MirroredReplicaThread
- t.merge_fn = fn
- t.merge_args = args
- t.merge_kwargs = kwargs
- t.captured_name_scope = t.graph.get_name_scope()
- # Adding a "/" at end lets us re-enter this scope later.
- if t.captured_name_scope:
- t.captured_name_scope += "/"
- t.has_paused.set()
- t.should_run.wait()
- t.should_run.clear()
- if t.coord.should_stop():
- raise _RequestedStop()
- return t.merge_result
-
+ # TODO(priyag): Delete this once all strategies use global batch size.
@property
- def devices(self):
- distribute_lib.require_replica_context(self)
- replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
- return [self._distribution_strategy.worker_devices[replica_id]]
+ def _global_batch_size(self):
+ return False
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 9fd4cca..b304f63 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -27,7 +27,10 @@
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import backprop
@@ -47,8 +50,6 @@
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.training import device_util
-from tensorflow.python.training import distribution_strategy_context as ds_context
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import server_lib
@@ -90,11 +91,11 @@
return list(range(replica_id))
with distribution.scope(), self.assertRaises(AssertionError):
- distribution.call_for_each_replica(run_fn)
+ distribution.extended.call_for_each_replica(run_fn)
def testReduceToCpu(self, distribution):
with distribution.scope():
- result = distribution.call_for_each_replica(_replica_id)
+ result = distribution.extended.call_for_each_replica(_replica_id)
reduced = distribution.reduce(
reduce_util.ReduceOp.SUM,
result,
@@ -114,7 +115,7 @@
expected_num_input_pipelines=1,
expected_input_pipeline_id=0)
iterator = distribution.make_input_fn_iterator(input_fn)
- self._test_input_fn_iterator(iterator, distribution.worker_devices,
+ self._test_input_fn_iterator(iterator, distribution.extended.worker_devices,
expected_values)
def testGlobalStepUpdate(self, distribution):
@@ -150,7 +151,7 @@
mode=["graph", "eager"]))
def testReduceToMultipleDestinations(self, distribution):
with distribution.scope():
- reduced = distribution.reduce(
+ reduced = distribution.extended.reduce_to(
reduce_util.ReduceOp.SUM,
1.0,
destinations=["/device:CPU:0", "/device:GPU:0"])
@@ -204,7 +205,7 @@
with context.graph_mode(), \
distribution.scope(), \
variable_scope.variable_creator_scope(main_thread_creator):
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
result = distribution.unwrap(result)
expected = ["main_thread:thread_0", "main_thread:thread_1"]
self.assertEqual(expected, result)
@@ -221,13 +222,13 @@
def model_fn():
# This variable should be created only once across the threads because of
# special variable_creator functions used by
- # `distribution.call_for_each_replica`.
+ # `distribution.extended.call_for_each_replica`.
v = variable_scope.variable(1.0, name="foo")
ds_context.get_replica_context().merge_call(lambda _: _)
return v
with distribution.scope():
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
self.assertIsInstance(result, values.MirroredVariable)
self.assertEqual("foo:0", result.name)
@@ -238,7 +239,7 @@
return v
with distribution.scope():
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
self.assertIsInstance(result, values.MirroredVariable)
# Default name of "Variable" will be used.
self.assertEqual("Variable:0", result.name)
@@ -252,7 +253,7 @@
return vs
with distribution.scope():
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
for i, v in enumerate(result):
self.assertIsInstance(v, values.MirroredVariable)
self.assertEqual("foo" + str(i) + ":0", v.name)
@@ -268,7 +269,7 @@
return vs
with distribution.scope():
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
for v in result:
self.assertIsInstance(v, values.MirroredVariable)
self.assertEqual(4, len(result))
@@ -285,7 +286,7 @@
return v
with distribution.scope():
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
self.assertIsInstance(result, values.MirroredVariable)
# The resulting mirrored variable will use the name from the first device.
self.assertEqual("foo_0:0", result.name)
@@ -316,7 +317,8 @@
features = iterator.get_next()
with distribution.scope():
- result = distribution.call_for_each_replica(model_fn, args=(features,))
+ result = distribution.extended.call_for_each_replica(
+ model_fn, args=(features,))
suffixes = ["", "_1", "_2"]
for (kernel, bias), suffix in zip(result, suffixes):
self.assertIsInstance(kernel, values.MirroredVariable)
@@ -348,7 +350,7 @@
v = variable_scope.variable(1.0, name="var-main0")
self.assertEqual("var-main0:0", v.name)
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual(4, len(result))
v0, v1, v2, v3 = result
self.assertIsInstance(v0, values.MirroredVariable)
@@ -385,7 +387,7 @@
v = variable_scope.get_variable("var-main0", [1])
self.assertEqual("main/var-main0:0", v.name)
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual(4, len(result))
v0, v1, v2, v3 = result
self.assertIsInstance(v0, values.MirroredVariable)
@@ -418,15 +420,15 @@
devices = ["/device:GPU:0", "/device:CPU:0"]
with distribution.scope():
- v0, v1 = distribution.call_for_each_replica(create_fn)
+ v0, v1 = distribution.extended.call_for_each_replica(create_fn)
self.evaluate(v0.initializer)
self.assertEqual(2.0, self.evaluate(v0.get(devices[0])))
self.assertEqual(2.0, self.evaluate(v0.get(devices[1])))
- self.assertEqual(2.0, self.evaluate(distribution.read_var(v0)))
+ self.assertEqual(2.0, self.evaluate(distribution.extended.read_var(v0)))
self.evaluate(v1.initializer)
self.assertEqual(3.0, self.evaluate(v1.get(devices[0])))
self.assertEqual(3.0, self.evaluate(v1.get(devices[1])))
- self.assertEqual(3.0, self.evaluate(distribution.read_var(v1)))
+ self.assertEqual(3.0, self.evaluate(distribution.extended.read_var(v1)))
def replica_id_plus_one():
return math_ops.cast(_replica_id() + 1, dtype=dtypes.float32)
@@ -437,7 +439,8 @@
update1 = v1.assign_add(7.0 * replica_id_plus_one())
return update0, update1
- update0a, update1a = distribution.call_for_each_replica(update_member_fn)
+ update0a, update1a = distribution.extended.call_for_each_replica(
+ update_member_fn)
# Update "sync on read" variable.
self.evaluate(distribution.group(update0a))
@@ -446,7 +449,8 @@
# so device[1] can end up with a different value.
self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1])))
# Always reads from device 0.
- self.assertEqual(2.0 + 5.0, self.evaluate(distribution.read_var(v0)))
+ self.assertEqual(2.0 + 5.0, self.evaluate(
+ distribution.extended.read_var(v0)))
# Update "sync on write" variable.
self.evaluate(distribution.group(update1a))
@@ -454,7 +458,8 @@
# Writes are synchronized for v1, only the argument to assign_add on
# device[0] is used.
self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1])))
- self.assertEqual(3.0 + 7.0, self.evaluate(distribution.read_var(v1)))
+ self.assertEqual(3.0 + 7.0, self.evaluate(
+ distribution.extended.read_var(v1)))
# Update using state_ops.assign_add global function.
def update_state_ops_fn():
@@ -462,7 +467,7 @@
update1 = state_ops.assign_add(v1, 13.0 * replica_id_plus_one())
return update0, update1
- update0b, update1b = distribution.call_for_each_replica(
+ update0b, update1b = distribution.extended.call_for_each_replica(
update_state_ops_fn)
self.evaluate(distribution.group(update0b))
@@ -470,14 +475,14 @@
self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0])))
self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1])))
self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(
- distribution.read_var(v0)))
+ distribution.extended.read_var(v0)))
# Update "sync on write" variable.
self.evaluate(distribution.group(update1b))
self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0])))
self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1])))
self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(
- distribution.read_var(v1)))
+ distribution.extended.read_var(v1)))
def testNoneSynchronizationWithGetVariable(self, distribution):
with distribution.scope():
@@ -540,7 +545,7 @@
"/device:GPU:0": "bar"
})
with self.assertRaises(RuntimeError):
- _ = distribution.call_for_each_replica(model_fn, args=(names,))
+ _ = distribution.extended.call_for_each_replica(model_fn, args=(names,))
def testReplicaLocalVariable(self, distribution):
all_v_sum = {}
@@ -575,7 +580,7 @@
with distribution.scope():
# Create "sum" and "mean" versions of ReplicaLocalVariables.
ret_ops, ret_v_sum, ret_v_mean, regrouped_sum, regrouped_mean = (
- distribution.call_for_each_replica(model_fn))
+ distribution.extended.call_for_each_replica(model_fn))
# Should see the same wrapping instance in all replicas.
self.assertIs(all_v_sum[0], ret_v_sum)
self.assertIs(all_v_mean[0], ret_v_mean)
@@ -609,9 +614,9 @@
# applying the reduction across all replicas (whether you use
# read_var(), get(), or nothing).
self.assertEqual(expected_sum, self.evaluate(
- distribution.read_var(ret_v_sum)))
+ distribution.extended.read_var(ret_v_sum)))
self.assertEqual(expected_mean, self.evaluate(
- distribution.read_var(ret_v_mean)))
+ distribution.extended.read_var(ret_v_mean)))
self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get()))
self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get()))
self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
@@ -631,7 +636,7 @@
return outputs
with context.graph_mode(), distribution.scope():
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
# Two variables are created by the RNN layer.
self.assertEqual(2, len(result))
for v in result:
@@ -652,7 +657,7 @@
return var.assign(value)
with distribution.scope():
- ret_v_sum = distribution.call_for_each_replica(model_fn)
+ ret_v_sum = distribution.extended.call_for_each_replica(model_fn)
# Initialize variables.
self.evaluate(variables.global_variables_initializer())
@@ -663,7 +668,8 @@
self.assertEqual(2.0, self.evaluate(ret_v_sum))
# Apply updates.
- update_ops = distribution.update(ret_v_sum, update, 5.0, grouped=False)
+ update_ops = distribution.extended.update(
+ ret_v_sum, update, args=(5.0,), group=False)
self.evaluate(update_ops)
# Assert that the aggregated value of the replica local vars is the sum
# of the individual values after running the update ops.
@@ -691,7 +697,7 @@
with context.graph_mode(), distribution.scope():
with ops.name_scope("main"):
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual(2, len(result))
for v, name in zip(result, ["a", "b"]):
self.assertIsInstance(v, values.DistributedValues)
@@ -708,7 +714,7 @@
return a, b
with context.graph_mode(), distribution.scope():
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
self.assertEqual(2, len(result))
for v, name in zip(result, ["a", "b"]):
self.assertIsInstance(v, values.DistributedValues)
@@ -734,7 +740,7 @@
with context.graph_mode(), distribution.scope():
with ops.name_scope("main"):
a = variable_scope.variable(1.0, name="a")
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
result_b = result[0]
result_c = result[1]
self.assertIsInstance(result_b, values.DistributedValues)
@@ -763,7 +769,7 @@
with context.graph_mode(), distribution.scope():
with ops.name_scope("main"):
a = variable_scope.get_variable("a", [1])
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
result_b = result[0]
result_c = result[1]
self.assertIsInstance(result_b, values.DistributedValues)
@@ -805,7 +811,7 @@
return v
with distribution.scope():
- result = distribution.call_for_each_replica(model_fn)
+ result = distribution.extended.call_for_each_replica(model_fn)
self.assertIsInstance(result, values.MirroredVariable)
self.assertEqual("foo:0", result.name)
@@ -828,7 +834,7 @@
return v
with distribution.scope():
- mirrored_var = distribution.call_for_each_replica(var_fn)
+ mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
@@ -839,7 +845,7 @@
ValueError, "You must specify an aggregation method to update a "
"MirroredVariable in Replica Context."):
self.evaluate(distribution.unwrap(
- distribution.call_for_each_replica(model_fn)))
+ distribution.extended.call_for_each_replica(model_fn)))
def testAssignMirroredVarReplicaContextWithSum(self, distribution):
# Test that we don't reduce a non-per-replica value with the "sum"
@@ -850,7 +856,7 @@
return v
with distribution.scope():
- mirrored_var = distribution.call_for_each_replica(var_fn)
+ mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
@@ -861,14 +867,14 @@
ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
"with the given reduce op ReduceOp.SUM."):
self.evaluate(distribution.unwrap(
- distribution.call_for_each_replica(model_fn)))
+ distribution.extended.call_for_each_replica(model_fn)))
def testAssignMirroredVarCrossDeviceContext(self, distribution):
def var_fn():
return variable_scope.variable(1.0, name="foo")
with distribution.scope():
- mirrored_var = distribution.call_for_each_replica(var_fn)
+ mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
@@ -881,7 +887,7 @@
1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
- mirrored_var = distribution.call_for_each_replica(var_fn)
+ mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
@@ -893,7 +899,7 @@
return mirrored_var.assign(value)
self.evaluate(distribution.unwrap(
- distribution.call_for_each_replica(model_fn)))
+ distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(0.5, self.evaluate(mirrored_var))
def testAssignMirroredVarReplicaContextWithSingleValue(self, distribution):
@@ -902,7 +908,7 @@
1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
- mirrored_var = distribution.call_for_each_replica(var_fn)
+ mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
@@ -911,7 +917,7 @@
return mirrored_var.assign(5.0)
self.evaluate(distribution.unwrap(
- distribution.call_for_each_replica(model_fn)))
+ distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(5.0, self.evaluate(mirrored_var))
def testAssignAddMirroredVarCrossDeviceContext(self, distribution):
@@ -919,7 +925,7 @@
return variable_scope.variable(1.0, name="foo")
with distribution.scope():
- mirrored_var = distribution.call_for_each_replica(var_fn)
+ mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
@@ -942,7 +948,7 @@
1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
- mirrored_var = distribution.call_for_each_replica(var_fn)
+ mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
@@ -954,7 +960,7 @@
return mirrored_var.assign_add(value)
self.evaluate(distribution.unwrap(
- distribution.call_for_each_replica(model_fn)))
+ distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(1.5, self.evaluate(mirrored_var))
def testAssignAddMirroredVarReplicaContextWithSingleValue(self, distribution):
@@ -963,7 +969,7 @@
1.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
- mirrored_var = distribution.call_for_each_replica(var_fn)
+ mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(1.0, self.evaluate(mirrored_var))
@@ -972,7 +978,7 @@
return mirrored_var.assign_add(5.0)
self.evaluate(distribution.unwrap(
- distribution.call_for_each_replica(model_fn)))
+ distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(6.0, self.evaluate(mirrored_var))
def testAssignSubMirroredVarCrossDeviceContext(self, distribution):
@@ -980,7 +986,7 @@
return variable_scope.variable(5.0, name="foo")
with distribution.scope():
- mirrored_var = distribution.call_for_each_replica(var_fn)
+ mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(5.0, self.evaluate(mirrored_var))
@@ -995,7 +1001,7 @@
5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
- mirrored_var = distribution.call_for_each_replica(var_fn)
+ mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(5.0, self.evaluate(mirrored_var))
@@ -1007,7 +1013,7 @@
return mirrored_var.assign_sub(value)
self.evaluate(distribution.unwrap(
- distribution.call_for_each_replica(model_fn)))
+ distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(4.5, self.evaluate(mirrored_var))
def testAssignSubMirroredVarReplicaContextWithSingleValue(self, distribution):
@@ -1016,7 +1022,7 @@
5.0, name="foo", aggregation=variable_scope.VariableAggregation.MEAN)
with distribution.scope():
- mirrored_var = distribution.call_for_each_replica(var_fn)
+ mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(5.0, self.evaluate(mirrored_var))
@@ -1025,7 +1031,7 @@
return mirrored_var.assign_sub(1.0)
self.evaluate(distribution.unwrap(
- distribution.call_for_each_replica(model_fn)))
+ distribution.extended.call_for_each_replica(model_fn)))
self.assertEqual(4.0, self.evaluate(mirrored_var))
@@ -1045,7 +1051,7 @@
return v
with distribution.scope():
- mirrored_var = distribution.call_for_each_replica(var_fn)
+ mirrored_var = distribution.extended.call_for_each_replica(var_fn)
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.assertFalse(self.evaluate(mirrored_var.is_initialized()))
self.evaluate(mirrored_var.initializer)
@@ -1064,7 +1070,8 @@
return v_sum
with distribution.scope():
- replica_local_var = distribution.call_for_each_replica(model_fn)
+ replica_local_var = distribution.extended.call_for_each_replica(
+ model_fn)
self.assertTrue(isinstance(replica_local_var,
values.ReplicaLocalVariable))
self.assertFalse(self.evaluate(replica_local_var.is_initialized()))
@@ -1088,7 +1095,7 @@
return v_sum
with distribution.scope():
- replica_local_var = distribution.call_for_each_replica(model_fn)
+ replica_local_var = distribution.extended.call_for_each_replica(model_fn)
self.assertTrue(isinstance(replica_local_var,
values.ReplicaLocalVariable))
self.evaluate(variables.global_variables_initializer())
@@ -1116,7 +1123,7 @@
return v_sum
with distribution.scope():
- replica_local_var = distribution.call_for_each_replica(model_fn)
+ replica_local_var = distribution.extended.call_for_each_replica(model_fn)
self.assertTrue(isinstance(replica_local_var,
values.ReplicaLocalVariable))
self.evaluate(variables.global_variables_initializer())
@@ -1181,7 +1188,7 @@
mock_model = MockModel(two_variables)
self.evaluate(variables.global_variables_initializer())
- result = distribution.call_for_each_replica(
+ result = distribution.extended.call_for_each_replica(
model_fn, args=[mock_model] + inputs)
for device in devices:
device_result = values.select_device(device, result)
@@ -1194,8 +1201,9 @@
# call_for_each has one trace per device. To check that the expected set
# of variables was accessed on each trace, we first retrieve each
# device-specific graph function.
- per_replica_graph_functions = distribution.call_for_each_replica(
- defun.get_concrete_function, args=[mock_model] + inputs)
+ per_replica_graph_functions = (
+ distribution.extended.call_for_each_replica(
+ defun.get_concrete_function, args=[mock_model] + inputs))
for device in devices:
graph_function = per_replica_graph_functions.get(device=device)
self.assertEqual(set(mock_model.variables),
@@ -1281,7 +1289,7 @@
gradients_fn = backprop.implicit_grad(loss_fn)
gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)
- grads_and_vars = distribution.call_for_each_replica(
+ grads_and_vars = distribution.extended.call_for_each_replica(
gradients_fn, args=(None,))
optimizer = gradient_descent.GradientDescentOptimizer(0.25)
@@ -1297,21 +1305,23 @@
self.assertAllEqual([0.5], updated_var_values[1])
-@combinations.generate(combinations.combine(
- distribution=[
- combinations.NamedDistribution(
- "Mirrored",
- # pylint: disable=g-long-lambda
- lambda: mirrored_strategy.CoreMirroredStrategy(
- num_gpus=context.num_gpus()),
- required_gpus=1),
- combinations.NamedDistribution(
- "CoreMirrored",
- # pylint: disable=g-long-lambda
- lambda: mirrored_strategy.CoreMirroredStrategy(
- num_gpus=context.num_gpus()),
- required_gpus=1)],
- mode=["graph"]))
+@combinations.generate(
+ combinations.combine(
+ distribution=[
+ combinations.NamedDistribution(
+ "Mirrored",
+ # pylint: disable=g-long-lambda
+ lambda: mirrored_strategy.CoreMirroredStrategy(
+ num_gpus_per_worker=context.num_gpus()),
+ required_gpus=1),
+ combinations.NamedDistribution(
+ "CoreMirrored",
+ # pylint: disable=g-long-lambda
+ lambda: mirrored_strategy.CoreMirroredStrategy(
+ num_gpus_per_worker=context.num_gpus()),
+ required_gpus=1)
+ ],
+ mode=["graph"]))
class MultiWorkerMirroredStrategyTest(
multi_worker_test_base.MultiWorkerTestBase,
strategy_test_lib.DistributionTestBase):
@@ -1361,7 +1371,16 @@
expected_input_pipeline_id=None)
iterator = distribution.make_input_fn_iterator(input_fn)
self._test_input_fn_iterator(
- iterator, distribution.worker_devices, expected_values, sess)
+ iterator, distribution.extended.worker_devices, expected_values, sess)
+
+ def testUpdateConfigProto(self, distribution):
+ distribution.configure(cluster_spec={"worker": ["fake1", "fake2"]})
+
+ config_proto = config_pb2.ConfigProto()
+ new_config = distribution.update_config_proto(config_proto)
+
+ # Verify isolate_session_state
+ self.assertTrue(new_config.isolate_session_state)
class MultiWorkerMirroredStrategyTestWithChief(
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 2f6d385..e322b6a 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -20,13 +20,14 @@
import six
+from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import values
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest
@@ -68,7 +69,9 @@
def _make_dataset_iterator(self, dataset):
"""Make iterator from dataset without splitting the batch."""
- return values.DatasetIterator(dataset, [("/job:localhost", [self._device])])
+ worker = device_util.canonicalize("/device:CPU:0")
+ worker_device_pairs = [(worker, [self._device])]
+ return values.DatasetIterator(dataset, worker_device_pairs)
def _distribute_dataset(self, dataset_fn):
return values.PerReplicaDataset(
@@ -78,8 +81,10 @@
self,
input_fn,
replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
+ worker = device_util.canonicalize("/device:CPU:0")
+ worker_device_pairs = [(worker, [self._device])]
return values.InputFunctionIterator(
- input_fn, [("/job:localhost", [self._device])],
+ input_fn, worker_device_pairs,
[distribute_lib.InputContext()])
def _broadcast_to(self, tensor, destinations):
@@ -194,6 +199,11 @@
def should_save_summary(self):
return True
+ # TODO(priyag): Delete this once all strategies use global batch size.
+ @property
+ def _global_batch_size(self):
+ return True
+
class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):
"""ReplicaContext for OneDeviceStrategy."""
@@ -205,9 +215,5 @@
replica_id_in_sync_group=constant_op.constant(0, dtypes.int32))
@property
- def device(self):
- raise RuntimeError("Use .devices instead")
-
- @property
def devices(self):
- return [self._distribution_strategy.worker_devices[0]]
+ return [self._distribution_strategy.extended.worker_devices[0]]
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py
index b0a2ba3..d46cd6f 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py
@@ -55,7 +55,7 @@
expected_input_pipeline_id=0)
iterator = d.make_input_fn_iterator(input_fn)
self._test_input_fn_iterator(
- iterator, d.worker_devices, expected_values)
+ iterator, d.extended.worker_devices, expected_values)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 6fc81a1..75ee41c 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -18,8 +18,12 @@
from __future__ import division
from __future__ import print_function
+import copy
+
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
+from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
@@ -30,8 +34,6 @@
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
-from tensorflow.python.training import device_util
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest
_LOCAL_CPU = "/device:CPU:0"
@@ -197,7 +199,7 @@
def _initialize_local(self, num_gpus_per_worker):
"""Initialize internal devices for local training."""
- self._worker_device = "/job:localhost"
+ self._worker_device = device_util.canonicalize("/device:CPU:0")
# Define compute devices which is a list of device strings and one for each
# replica. When there are GPUs, replicate operations on these GPUs.
# Otherwise, place operations on CPU.
@@ -462,21 +464,27 @@
self._initialize_multi_worker(self._num_gpus_per_worker,
self._cluster_spec, task_type, task_id)
- if not session_config or not self._cluster_spec:
- return
+ if session_config:
+ session_config.CopyFrom(self._update_config_proto(session_config))
- session_config.isolate_session_state = False
+ def _update_config_proto(self, config_proto):
+ updated_config = copy.deepcopy(config_proto)
+ if not self._cluster_spec:
+ updated_config.isolate_session_state = True
+ return updated_config
- assert self._cluster_spec
+ updated_config.isolate_session_state = False
+
assert self._task_type
assert self._task_id is not None
# The device filters prevent communication between workers.
if self._task_type not in ["chief", "worker"]:
- return
- del session_config.device_filters[:]
- session_config.device_filters.extend(
+ return updated_config
+ del updated_config.device_filters[:]
+ updated_config.device_filters.extend(
["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
+ return updated_config
@property
def _num_replicas_in_sync(self):
@@ -510,3 +518,8 @@
@property
def should_save_summary(self):
return self._is_chief
+
+ # TODO(priyag): Delete this once all strategies use global batch size.
+ @property
+ def _global_batch_size(self):
+ return False
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index b4c098a..4debe72 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -28,6 +28,8 @@
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
@@ -46,8 +48,6 @@
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.training import device_util
-from tensorflow.python.training import distribution_strategy_context as ds_context
from tensorflow.python.training import training_util
CHIEF = run_config.TaskType.CHIEF
@@ -522,7 +522,7 @@
expected_values):
distribution, master_target, config = self._get_test_objects(
task_type, task_id, num_gpus)
- devices = distribution.worker_devices
+ devices = distribution.extended.worker_devices
with ops.Graph().as_default(), \
self.cached_session(config=config,
@@ -656,6 +656,33 @@
num_gpus_per_worker=context.num_gpus())
self._test_global_step_update(strategy)
+ def testUpdateConfigProtoMultiWorker(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=2)
+ distribution.configure(
+ cluster_spec=self._cluster_spec, task_type='worker', task_id=1)
+
+ config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden'])
+
+ new_config = distribution.update_config_proto(config_proto)
+
+ # Verify device filters.
+ self.assertEqual(['/job:worker/task:1', '/job:ps'],
+ new_config.device_filters)
+
+ # Verify isolate_session_state
+ self.assertFalse(new_config.isolate_session_state)
+
+ def testUpdateConfigProtoLocal(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=2)
+
+ config_proto = config_pb2.ConfigProto()
+ new_config = distribution.update_config_proto(config_proto)
+
+ # Verify isolate_session_state
+ self.assertTrue(new_config.isolate_session_state)
+
class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
parameterized.TestCase):
@@ -698,9 +725,9 @@
v = variable_scope.get_variable('v', initializer=10.0)
_ = v * v
v, = tape.watched_variables()
- w = distribution.value_container(v)
+ w = distribution.extended.value_container(v)
self.assertIs(values.AggregatingVariable, type(w))
- distribution.call_for_each_replica(f)
+ distribution.extended.call_for_each_replica(f)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index de0abc6..756e5bd 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -19,6 +19,7 @@
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import backprop
@@ -33,7 +34,6 @@
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.training import distribution_strategy_context as ds_context
from tensorflow.python.training import optimizer
@@ -191,17 +191,18 @@
def _test_replica_id(self, d):
with d.scope():
- expected_devices = [False] * len(d.worker_devices)
+ expected_devices = [False] * len(d.extended.worker_devices)
def mark_devices_fn():
replica_id = self.evaluate(
ds_context.get_replica_context().replica_id_in_sync_group)
- self.assertLess(replica_id, len(d.worker_devices))
+ self.assertLess(replica_id, len(d.extended.worker_devices))
self.assertFalse(expected_devices[replica_id])
expected_devices[replica_id] = True
d.call_for_each_replica(mark_devices_fn)
- self.assertAllEqual(expected_devices, [True] * len(d.worker_devices))
+ self.assertAllEqual(expected_devices,
+ [True] * len(d.extended.worker_devices))
def _test_call_and_merge_exceptions(self, dist):
with dist.scope():
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 314dcc5..1f302fd 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -21,6 +21,7 @@
from __future__ import division
from __future__ import print_function
+import copy
import functools
from tensorflow.contrib.tpu.python.ops import tpu_ops
@@ -28,6 +29,8 @@
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
+from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
@@ -40,8 +43,6 @@
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.training import device_util
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest
@@ -254,7 +255,7 @@
self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
output_shapes = multi_worker_iterator.output_shapes
shapes = nest.flatten(output_shapes)
- if any([not s.is_fully_defined() for s in shapes]):
+ if any(not s.is_fully_defined() for s in shapes):
raise ValueError(
"TPU currently requires fully defined shapes. Either use "
"set_shape() on the input tensors or use "
@@ -539,10 +540,20 @@
task_id=None):
del cluster_spec, task_type, task_id
if session_config:
- session_config.isolate_session_state = True
- cluster_spec = self._tpu_cluster_resolver.cluster_spec()
- if cluster_spec:
- session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+ session_config.CopyFrom(self._update_config_proto(session_config))
+
+ def _update_config_proto(self, config_proto):
+ updated_config = copy.deepcopy(config_proto)
+ updated_config.isolate_session_state = True
+ cluster_spec = self._tpu_cluster_resolver.cluster_spec()
+ if cluster_spec:
+ updated_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+ return updated_config
+
+ # TODO(priyag): Delete this once all strategies use global batch size.
+ @property
+ def _global_batch_size(self):
+ return True
class _TPUReplicaContext(distribute_lib.ReplicaContext):
@@ -557,12 +568,8 @@
replica_id_in_sync_group=constant_op.constant(0, dtypes.int32))
@property
- def device(self):
- raise RuntimeError("Use .devices instead")
-
- @property
def devices(self):
distribute_lib.require_replica_context(self)
ds = self._distribution_strategy
replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
- return [ds.worker_devices[replica_id]]
+ return [ds.extended.worker_devices[replica_id]]
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 855b9c2..538b859 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -25,6 +25,8 @@
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.eager import test
@@ -39,8 +41,6 @@
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
-from tensorflow.python.training import device_util
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py
index 29eeaf4..ab3c071 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/normal_conjugate_posteriors_test.py
@@ -82,7 +82,7 @@
x = constant_op.constant(
[[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], [2.5, -2.5, -4.0, 0.0, 1.0, -2.0]],
dtype=dtypes.float32)
- s = math_ops.reduce_sum(x, reduction_indices=[1])
+ s = math_ops.reduce_sum(x, axis=[1])
x = array_ops.transpose(x) # Reshape to shape (6, 2)
n = constant_op.constant([6] * 2)
prior = distributions.Normal(loc=mu0, scale=sigma0)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py
index a60056c..cdee30b 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py
@@ -147,14 +147,13 @@
x = chol_w.sample(10000, seed=42)
self.assertAllEqual((10000, 3, 3), x.get_shape())
- moment1_estimate = math_ops.reduce_mean(x, reduction_indices=[0]).eval()
+ moment1_estimate = math_ops.reduce_mean(x, axis=[0]).eval()
self.assertAllClose(chol_w.mean().eval(), moment1_estimate, rtol=0.05)
# The Variance estimate uses the squares rather than outer-products
# because Wishart.Variance is the diagonal of the Wishart covariance
# matrix.
- variance_estimate = (math_ops.reduce_mean(
- math_ops.square(x), reduction_indices=[0]) -
+ variance_estimate = (math_ops.reduce_mean(math_ops.square(x), axis=[0]) -
math_ops.square(moment1_estimate)).eval()
self.assertAllClose(
chol_w.variance().eval(), variance_estimate, rtol=0.05)
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py
index 15c241d..74765f1 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py
@@ -168,7 +168,7 @@
# log_normalization = 1 + reduce_sum(exp(logits))
# -log_normalization + reduce_sum(logits - log_normalization)
log_normalization = nn_ops.softplus(
- math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True))
+ math_ops.reduce_logsumexp(x, axis=-1, keepdims=True))
return array_ops.squeeze(
(-log_normalization + math_ops.reduce_sum(
x - log_normalization, axis=-1, keepdims=True)), axis=-1)
diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index c88c0f5..566246d 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -24,6 +24,7 @@
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import smart_cond
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@@ -354,9 +355,10 @@
def write_summary_f():
summary_ops.scalar(name=self.name, tensor=t)
return t
- control_flow_ops.cond(write_summary,
+ smart_cond.smart_cond(write_summary,
write_summary_f,
- lambda: t)
+ lambda: t,
+ name="")
return t
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index 9d2d172..39e5957 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -49,18 +49,6 @@
self.assertEqual(dtypes.float64, m.dtype)
self.assertEqual(dtypes.float64, m.result().dtype)
- def testSummaryArg(self):
- m = metrics.Mean()
- m([1, 10, 100])
- m(1000)
- m([10000.0, 100000.0])
- self.assertEqual(111111.0/6, m.result(write_summary=True).numpy())
- self.assertEqual(111111.0/6, m.result(write_summary=False).numpy())
- with self.assertRaises(ValueError):
- m.result(write_summary=5)
- with self.assertRaises(ValueError):
- m.result(write_summary=[True])
-
def testVariableCollections(self):
with context.graph_mode(), ops.Graph().as_default():
m = metrics.Mean()
diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py
index f801d9a..5cc0c4f 100644
--- a/tensorflow/contrib/eager/python/network.py
+++ b/tensorflow/contrib/eager/python/network.py
@@ -24,7 +24,7 @@
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
-from tensorflow.python.keras.engine import base_layer as keras_base_layer
+from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.layers import base
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
@@ -220,7 +220,7 @@
avoid_names = parent_network._owned_layers
name_uid_map = parent_network._sub_layer_name_uids
else:
- name_uid_map = keras_base_layer.get_default_graph_uid_map()
+ name_uid_map = base_layer_utils.get_default_graph_uid_map()
# Figure out which names we have to avoid based on which variable scope
# we're nested in.
strip_name = self._default_parent_variable_scope.name
diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py
index 4454abf..8c35ddd 100644
--- a/tensorflow/contrib/eager/python/tfe_test.py
+++ b/tensorflow/contrib/eager/python/tfe_test.py
@@ -87,8 +87,8 @@
x += 1.
# Without a device context, heuristics are used to place ops.
# In this case, ops.reduce_mean runs on the GPU.
- reduction_indices = range(x.shape.ndims)
- m = math_ops.reduce_mean(x, reduction_indices)
+ axis = range(x.shape.ndims)
+ m = math_ops.reduce_mean(x, axis)
# m is on GPU, bring it back to CPU and compare.
self.assertEqual(3.5, m.cpu().numpy())
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 37f253d..a888379 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -16,7 +16,6 @@
srcs_version = "PY2AND3",
deps = [
":boosted_trees",
- ":dnn",
":dnn_with_layer_annotations",
":early_stopping",
":expect_tensorflow_estimator_installed",
@@ -25,7 +24,6 @@
":extenders",
":head",
":hooks",
- ":linear",
":logit_fns",
":multi_head",
":replicate_model_fn",
@@ -48,18 +46,6 @@
)
py_library(
- name = "dnn",
- srcs = ["python/estimator/dnn.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":expect_tensorflow_estimator_installed",
- "//tensorflow:tensorflow_py_no_contrib",
- "//tensorflow/python/estimator",
- "//tensorflow/python/estimator:dnn",
- ],
-)
-
-py_library(
name = "dnn_with_layer_annotations",
srcs = ["python/estimator/dnn_with_layer_annotations.py"],
srcs_version = "PY2AND3",
@@ -145,17 +131,6 @@
)
py_library(
- name = "linear",
- srcs = ["python/estimator/linear.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":expect_tensorflow_estimator_installed",
- "//tensorflow/python/estimator",
- "//tensorflow/python/estimator:linear",
- ],
-)
-
-py_library(
name = "logit_fns",
srcs = [
"python/estimator/logit_fns.py",
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 80d5962..7d61247 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -58,8 +58,6 @@
'multi_label_head',
'poisson_regression_head',
'regression_head',
- 'DNNEstimator',
- 'LinearEstimator',
'boosted_trees_classifier_train_in_memory',
'boosted_trees_regressor_train_in_memory',
'call_logit_fn',
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
deleted file mode 100644
index 7894418..0000000
--- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""dnn_linear_combined python module.
-
-Importing from tensorflow.python.estimator is unsupported
-and will soon break!
-"""
-# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow_estimator.contrib.estimator.python.estimator import dnn_linear_combined
-
-# Include attrs that start with single underscore.
-_HAS_DYNAMIC_ATTRIBUTES = True
-dnn_linear_combined.__all__ = [
- s for s in dir(dnn_linear_combined) if not s.startswith('__')
-]
-
-from tensorflow_estimator.contrib.estimator.python.estimator.dnn_linear_combined import *
diff --git a/tensorflow/contrib/estimator/python/estimator/linear.py b/tensorflow/contrib/estimator/python/estimator/linear.py
deleted file mode 100644
index b6a4444..0000000
--- a/tensorflow/contrib/estimator/python/estimator/linear.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""linear python module.
-
-Importing from tensorflow.python.estimator is unsupported
-and will soon break!
-"""
-# pylint: disable=unused-import,g-bad-import-order,g-import-not-at-top,wildcard-import
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow_estimator.contrib.estimator.python.estimator import linear
-
-# Include attrs that start with single underscore.
-_HAS_DYNAMIC_ATTRIBUTES = True
-linear.__all__ = [s for s in dir(linear) if not s.startswith('__')]
-
-from tensorflow_estimator.contrib.estimator.python.estimator.linear import *
diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD
index cd747df..53efae1 100644
--- a/tensorflow/contrib/framework/BUILD
+++ b/tensorflow/contrib/framework/BUILD
@@ -66,6 +66,7 @@
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:smart_cond",
+ "//tensorflow/python:sort_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:state_ops",
"//tensorflow/python:state_ops_gen",
@@ -311,17 +312,3 @@
"//third_party/py/numpy",
],
)
-
-py_test(
- name = "sort_ops_test",
- size = "medium",
- srcs = ["python/ops/sort_ops_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":framework_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:random_ops",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/framework/python/ops/sort_ops.py b/tensorflow/contrib/framework/python/ops/sort_ops.py
index 1921a77..42184a4 100644
--- a/tensorflow/contrib/framework/python/ops/sort_ops.py
+++ b/tensorflow/contrib/framework/python/ops/sort_ops.py
@@ -22,173 +22,7 @@
from __future__ import division
from __future__ import print_function
-import numpy as np
+from tensorflow.python.ops import sort_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import ops as framework_ops
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-
-
-def sort(values, axis=-1, direction='ASCENDING', name=None):
- """Sorts a tensor.
-
- Args:
- values: 1-D or higher numeric `Tensor`.
- axis: The axis along which to sort. The default is -1, which sorts the last
- axis.
- direction: The direction in which to sort the values (`'ASCENDING'` or
- `'DESCENDING'`).
- name: Optional name for the operation.
-
- Returns:
- A `Tensor` with the same dtype and shape as `values`, with the elements
- sorted along the given `axis`.
-
- Raises:
- ValueError: If axis is not a constant scalar, or the direction is invalid.
- """
- with framework_ops.name_scope(name, 'sort'):
- return _sort_or_argsort(values, axis, direction, return_argsort=False)
-
-
-def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None):
- """Returns the indices of a tensor that give its sorted order along an axis.
-
- For a 1D tensor, `tf.gather(values, tf.argsort(values))` is equivalent to
- `tf.sort(values)`. For higher dimensions, the output has the same shape as
- `values`, but along the given axis, values represent the index of the sorted
- element in that slice of the tensor at the given position.
-
- Args:
- values: 1-D or higher numeric `Tensor`.
- axis: The axis along which to sort. The default is -1, which sorts the last
- axis.
- direction: The direction in which to sort the values (`'ASCENDING'` or
- `'DESCENDING'`).
- stable: If True, equal elements in the original tensor will not be
- re-ordered in the returned order. Unstable sort is not yet implemented,
- but will eventually be the default for performance reasons. If you
- require a stable order, pass `stable=True` for forwards compatibility.
- name: Optional name for the operation.
-
- Returns:
- An int32 `Tensor` with the same shape as `values`. The indices that would
- sort each slice of the given `values` along the given `axis`.
-
- Raises:
- ValueError: If axis is not a constant scalar, or the direction is invalid.
- """
- del stable # Unused.
- with framework_ops.name_scope(name, 'argsort'):
- return _sort_or_argsort(values, axis, direction, return_argsort=True)
-
-
-def _sort_or_argsort(values, axis, direction, return_argsort):
- """Internal sort/argsort implementation.
-
- Args:
- values: The input values.
- axis: The axis along which to sort.
- direction: 'ASCENDING' or 'DESCENDING'.
- return_argsort: Whether to return the argsort result.
-
- Returns:
- Either the sorted values, or the indices of the sorted values in the
- original tensor. See the `sort` and `argsort` docstrings.
-
- Raises:
- ValueError: If axis is not a constant scalar, or the direction is invalid.
- """
- if direction not in _SORT_IMPL:
- raise ValueError('%s should be one of %s' %
- (direction, ', '.join(sorted(_SORT_IMPL.keys()))))
- # Axis must be an integer, not a Tensor.
- axis = framework_ops.convert_to_tensor(axis, name='axis')
- axis_static = tensor_util.constant_value(axis)
- if axis.shape.ndims != 0 or axis_static is None:
- raise ValueError('axis must be a constant scalar')
- axis_static = int(axis_static) # Avoids NumPy casting error
-
- values = framework_ops.convert_to_tensor(values, name='values')
-
- return _SORT_IMPL[direction](values, axis_static, return_argsort)
-
-
-def _descending_sort(values, axis, return_argsort=False):
- """Sorts values in reverse using `top_k`.
-
- Args:
- values: Tensor of numeric values.
- axis: Index of the axis which values should be sorted along.
- return_argsort: If False, return the sorted values. If True, return the
- indices that would sort the values.
-
- Returns:
- The sorted values.
- """
- k = array_ops.shape(values)[axis]
- rank = array_ops.rank(values)
- static_rank = values.shape.ndims
- # Fast path: sorting the last axis.
- if axis == -1 or axis + 1 == values.get_shape().ndims:
- top_k_input = values
- transposition = None
- else:
- # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`.
- if axis < 0:
- # Calculate the actual axis index if counting from the end. Use the static
- # rank if available, or else make the axis back into a tensor.
- axis += static_rank or rank
- if static_rank is not None:
- # Prefer to calculate the transposition array in NumPy and make it a
- # constant.
- transposition = constant_op.constant(
- np.r_[
- # Axes up to axis are unchanged.
- np.arange(axis),
- # Swap axis and rank - 1.
- [static_rank - 1],
- # Axes in [axis + 1, rank - 1) are unchanged.
- np.arange(axis + 1, static_rank - 1),
- # Swap axis and rank - 1.
- [axis]],
- name='transposition')
- else:
- # Generate the transposition array from the tensors.
- transposition = array_ops.concat(
- [
- # Axes up to axis are unchanged.
- math_ops.range(axis),
- # Swap axis and rank - 1.
- [rank - 1],
- # Axes in [axis + 1, rank - 1) are unchanged.
- math_ops.range(axis + 1, rank - 1),
- # Swap axis and rank - 1.
- [axis]
- ],
- axis=0)
- top_k_input = array_ops.transpose(values, transposition)
-
- values, indices = nn_ops.top_k(top_k_input, k)
- return_value = indices if return_argsort else values
- if transposition is not None:
- # transposition contains a single cycle of length 2 (swapping 2 elements),
- # so it is an involution (it is its own inverse).
- return_value = array_ops.transpose(return_value, transposition)
- return return_value
-
-
-def _ascending_sort(values, axis, return_argsort=False):
- # Negate the values to get the ascending order from descending sort.
- values_or_indices = _descending_sort(-values, axis, return_argsort)
- # If not argsort, negate the values again.
- return values_or_indices if return_argsort else -values_or_indices
-
-
-_SORT_IMPL = {
- 'ASCENDING': _ascending_sort,
- 'DESCENDING': _descending_sort,
-}
+sort = sort_ops.sort
+argsort = sort_ops.argsort
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index 219cc19..3593b50 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -113,7 +113,8 @@
add_summaries=None,
use_loss_summaries=True,
config=None,
- warm_start_from=None):
+ warm_start_from=None,
+ is_chief=True):
"""Initializes a GANEstimator instance.
Args:
@@ -154,6 +155,8 @@
config: `RunConfig` object to configure the runtime settings.
warm_start_from: A filepath to a checkpoint or saved model, or a
WarmStartSettings object to configure initialization.
+ is_chief: Whether or not this Estimator is running on a chief or worker.
+ Needs to be set appropriately if using SyncReplicasOptimizers.
Raises:
ValueError: If loss functions aren't callable.
@@ -187,7 +190,7 @@
return _get_estimator_spec(
mode, gan_model, generator_loss_fn, discriminator_loss_fn,
get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
- get_hooks_fn, use_loss_summaries)
+ get_hooks_fn, use_loss_summaries, is_chief)
super(GANEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config,
@@ -215,7 +218,7 @@
def _get_estimator_spec(
mode, gan_model, generator_loss_fn, discriminator_loss_fn,
get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
- get_hooks_fn=None, use_loss_summaries=True):
+ get_hooks_fn=None, use_loss_summaries=True, is_chief=True):
"""Get the EstimatorSpec for the current mode."""
if mode == model_fn_lib.ModeKeys.PREDICT:
estimator_spec = model_fn_lib.EstimatorSpec(
@@ -236,7 +239,7 @@
else discriminator_optimizer)
get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
estimator_spec = _get_train_estimator_spec(
- gan_model, gan_loss, gopt, dopt, get_hooks_fn)
+ gan_model, gan_loss, gopt, dopt, get_hooks_fn, is_chief=is_chief)
return estimator_spec
@@ -321,11 +324,11 @@
def _get_train_estimator_spec(
gan_model, gan_loss, generator_optimizer, discriminator_optimizer,
- get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops):
+ get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops, is_chief=True):
"""Return an EstimatorSpec for the train case."""
scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer,
- discriminator_optimizer)
+ discriminator_optimizer, is_chief=is_chief)
training_hooks = get_hooks_fn(train_ops)
return model_fn_lib.EstimatorSpec(
loss=scalar_loss,
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
index 3d6bdab..bc90210 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
@@ -48,6 +48,7 @@
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import learning_rate_decay
+from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training
from tensorflow.python.training import training_util
@@ -82,7 +83,7 @@
self.assertEqual(generator_inputs, gan_model.generator_inputs)
self.assertIsNotNone(gan_model.generated_data)
- self.assertEqual(2, len(gan_model.generator_variables)) # 1 FC layer
+ self.assertLen(gan_model.generator_variables, 2) # 1 FC layer
self.assertIsNotNone(gan_model.generator_fn)
if mode == model_fn_lib.ModeKeys.PREDICT:
self.assertIsNone(gan_model.real_data)
@@ -95,7 +96,7 @@
self.assertIsNotNone(gan_model.real_data)
self.assertIsNotNone(gan_model.discriminator_real_outputs)
self.assertIsNotNone(gan_model.discriminator_gen_outputs)
- self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer
+ self.assertLen(gan_model.discriminator_variables, 2) # 1 FC layer
self.assertIsNotNone(gan_model.discriminator_scope)
self.assertIsNotNone(gan_model.discriminator_fn)
@@ -121,6 +122,7 @@
def dummy_loss_fn(gan_model, add_summaries=True):
+ del add_summaries
return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
gan_model.discriminator_gen_outputs)
@@ -168,6 +170,35 @@
self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar
self.assertIsNotNone(spec.eval_metric_ops)
+ def test_get_sync_estimator_spec(self):
+ """Make sure spec is loaded with sync hooks for sync opts."""
+
+ def get_sync_optimizer():
+ return sync_replicas_optimizer.SyncReplicasOptimizer(
+ training.GradientDescentOptimizer(learning_rate=1.0),
+ replicas_to_aggregate=1)
+
+ with ops.Graph().as_default():
+ self._gan_model = get_dummy_gan_model()
+ g_opt = get_sync_optimizer()
+ d_opt = get_sync_optimizer()
+
+ spec = estimator._get_estimator_spec(
+ model_fn_lib.ModeKeys.TRAIN,
+ self._gan_model,
+ generator_loss_fn=dummy_loss_fn,
+ discriminator_loss_fn=dummy_loss_fn,
+ get_eval_metric_ops_fn=get_metrics,
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt)
+
+ self.assertLen(spec.training_hooks, 4)
+ sync_opts = [
+ hook._sync_optimizer for hook in spec.training_hooks if
+ isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)]
+ self.assertLen(sync_opts, 2)
+ self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
+
# TODO(joelshor): Add pandas test.
class GANEstimatorIntegrationTest(test.TestCase):
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
index df0342c..a0a86c6 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
@@ -36,7 +36,6 @@
from __future__ import division
from __future__ import print_function
-import numpy as np
from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib
from tensorflow.python.framework import ops
@@ -47,7 +46,6 @@
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops.distributions import distribution as ds
from tensorflow.python.ops.losses import losses
from tensorflow.python.ops.losses import util
from tensorflow.python.summary import summary
@@ -740,11 +738,16 @@
def _validate_distributions(distributions):
if not isinstance(distributions, (list, tuple)):
raise ValueError('`distributions` must be a list or tuple. Instead, '
- 'found %s.', type(distributions))
+ 'found %s.' % type(distributions))
for x in distributions:
- if not isinstance(x, ds.Distribution):
+ # We used to check with `isinstance(x, tf.distributions.Distribution)`.
+ # However, distributions have migrated to `tfp.distributions.Distribution`,
+ # which is a new code repo, so we can't check this way anymore until
+ # TF-GAN is migrated to a new repo as well.
+ # This new check is not sufficient, but is a useful heuristic for now.
+ if not callable(getattr(x, 'log_prob', None)):
raise ValueError('`distributions` must be a list of `Distributions`. '
- 'Instead, found %s.', type(x))
+ 'Instead, found %s.' % type(x))
def _validate_information_penalty_inputs(
@@ -817,7 +820,7 @@
Returns:
A scalar tensor with the global norm.
"""
- if np.all([x is None for x in tensor_list]):
+ if all(x is None for x in tensor_list):
return 0.0
list_max = math_ops.reduce_max([math_ops.reduce_max(math_ops.abs(x)) for x in
diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py
index b9ac1bf..969b684 100644
--- a/tensorflow/contrib/gan/python/namedtuples.py
+++ b/tensorflow/contrib/gan/python/namedtuples.py
@@ -213,7 +213,8 @@
collections.namedtuple('GANTrainOps', (
'generator_train_op',
'discriminator_train_op',
- 'global_step_inc_op'
+ 'global_step_inc_op',
+ 'train_hooks'
))):
"""GANTrainOps contains the training ops.
@@ -221,8 +222,17 @@
generator_train_op: Op that performs a generator update step.
discriminator_train_op: Op that performs a discriminator update step.
global_step_inc_op: Op that increments the shared global step.
+ train_hooks: a list or tuple containing hooks related to training that need
+ to be populated when training ops are instantiated. Used primarily for
+ sync hooks.
"""
+ def __new__(cls, generator_train_op, discriminator_train_op,
+ global_step_inc_op, train_hooks=()):
+ return super(GANTrainOps, cls).__new__(cls, generator_train_op,
+ discriminator_train_op,
+ global_step_inc_op, train_hooks)
+
class GANTrainSteps(
collections.namedtuple('GANTrainSteps', (
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index cf5b9d9..4c7bee4 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -924,6 +924,7 @@
generator_optimizer,
discriminator_optimizer,
check_for_unused_update_ops=True,
+ is_chief=True,
# Optional args to pass directly to the `create_train_op`.
**kwargs):
"""Returns GAN train ops.
@@ -939,6 +940,8 @@
discriminator_optimizer: The optimizer for the discriminator updates.
check_for_unused_update_ops: If `True`, throws an exception if there are
update ops outside of the generator or discriminator scopes.
+ is_chief: Specifies whether or not the training is being run by the primary
+ replica during replica training.
**kwargs: Keyword args to pass directly to
`training.create_train_op` for both the generator and
discriminator train op.
@@ -980,6 +983,9 @@
kwargs, model.generator_scope.name, model.discriminator_scope.name,
check_for_unused_update_ops)
+ # Get the sync hooks if these are needed.
+ sync_hooks = []
+
generator_global_step = None
if isinstance(generator_optimizer,
sync_replicas_optimizer.SyncReplicasOptimizer):
@@ -995,6 +1001,7 @@
trainable=False,
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
gen_update_ops += [generator_global_step.assign(global_step)]
+ sync_hooks.append(generator_optimizer.make_session_run_hook(is_chief))
with ops.name_scope('generator_train'):
gen_train_op = training.create_train_op(
total_loss=loss.generator_loss,
@@ -1016,6 +1023,7 @@
trainable=False,
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
dis_update_ops += [discriminator_global_step.assign(global_step)]
+ sync_hooks.append(discriminator_optimizer.make_session_run_hook(is_chief))
with ops.name_scope('discriminator_train'):
disc_train_op = training.create_train_op(
total_loss=loss.discriminator_loss,
@@ -1025,7 +1033,8 @@
update_ops=dis_update_ops,
**kwargs)
- return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc)
+ return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc,
+ sync_hooks)
# TODO(joelshor): Implement a dynamic GAN train loop, as in `Real-Time Adaptive
@@ -1066,7 +1075,7 @@
train_steps.generator_train_steps)
discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
train_steps.discriminator_train_steps)
- return [generator_hook, discriminator_hook]
+ return [generator_hook, discriminator_hook] + list(train_ops.train_hooks)
return get_hooks
@@ -1126,7 +1135,7 @@
g_hook = RunTrainOpsHook(g_op, num_g_steps)
d_hook = RunTrainOpsHook(d_op, num_d_steps)
- return [joint_hook, g_hook, d_hook]
+ return [joint_hook, g_hook, d_hook] + list(train_ops.train_hooks)
return get_hooks
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index 31d9e82..841f25c 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -759,7 +759,7 @@
# For [pool_size, ?), the pool is full, tensor2 must be equal to some
# historical values of tensor1 (which is previously stored in the
# pool).
- self.assertTrue(any([(v == t2).all() for v in history_values]))
+ self.assertTrue(any((v == t2).all() for v in history_values))
def _make_new_model_and_check(self, model, pool_size):
pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=pool_size)
@@ -836,6 +836,9 @@
self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
+ # Make sure there are no training hooks populated accidentally.
+ self.assertEmpty(train_ops.train_hooks)
+
# TODO(joelshor): Add a test to check that custom update op is run.
@parameterized.named_parameters(
('gan', create_gan_model, False),
@@ -925,6 +928,14 @@
# No new trainable variables should have been added.
self.assertLen(variables_lib.get_trainable_variables(), num_trainable_vars)
+ # Sync hooks should be populated in the GANTrainOps.
+ self.assertLen(train_ops.train_hooks, 2)
+ for hook in train_ops.train_hooks:
+ self.assertIsInstance(
+ hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
+ sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks]
+ self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
+
g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)
@@ -958,6 +969,32 @@
coord.request_stop()
coord.join(g_threads + d_threads)
+ @parameterized.named_parameters(
+ ('is_chief', True),
+ ('is_not_chief', False),
+ )
+ def test_is_chief_in_train_hooks(self, is_chief):
+ """Make sure is_chief is propagated correctly to sync hooks."""
+ model = create_gan_model()
+ loss = train.gan_loss(model)
+ g_opt = get_sync_optimizer()
+ d_opt = get_sync_optimizer()
+ train_ops = train.gan_train_ops(
+ model,
+ loss,
+ g_opt,
+ d_opt,
+ is_chief=is_chief,
+ summarize_gradients=True,
+ colocate_gradients_with_ops=True)
+
+ self.assertLen(train_ops.train_hooks, 2)
+ for hook in train_ops.train_hooks:
+ self.assertIsInstance(
+ hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
+ is_chief_list = [hook._is_chief for hook in train_ops.train_hooks]
+ self.assertListEqual(is_chief_list, [is_chief, is_chief])
+
class GANTrainTest(test.TestCase, parameterized.TestCase):
"""Tests for `gan_train`."""
@@ -1035,6 +1072,44 @@
self.assertTrue(np.isscalar(final_loss))
self.assertEqual(17.0, final_loss)
+ @parameterized.named_parameters(
+ ('gan', create_gan_model),
+ ('callable_gan', create_callable_gan_model),
+ ('infogan', create_infogan_model),
+ ('callable_infogan', create_callable_infogan_model),
+ ('acgan', create_acgan_model),
+ ('callable_acgan', create_callable_acgan_model),
+ )
+ def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
+ model = create_gan_model_fn()
+ loss = train.gan_loss(model)
+
+ g_opt = get_sync_optimizer()
+ d_opt = get_sync_optimizer()
+ train_ops = train.gan_train_ops(
+ model,
+ loss,
+ g_opt,
+ d_opt,
+ summarize_gradients=True,
+ colocate_gradients_with_ops=True)
+
+ sequential_train_hooks = train.get_sequential_train_hooks()(train_ops)
+ self.assertLen(sequential_train_hooks, 4)
+ sync_opts = [
+ hook._sync_optimizer for hook in sequential_train_hooks if
+ isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)]
+ self.assertLen(sync_opts, 2)
+ self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
+
+ joint_train_hooks = train.get_joint_train_hooks()(train_ops)
+ self.assertLen(joint_train_hooks, 5)
+ sync_opts = [
+ hook._sync_optimizer for hook in joint_train_hooks if
+ isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)]
+ self.assertLen(sync_opts, 2)
+ self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
+
class PatchGANTest(test.TestCase, parameterized.TestCase):
"""Tests that functions work on PatchGAN style output."""
diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
index 94f522c..fbccbea 100644
--- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
+++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc
@@ -170,6 +170,14 @@
// Record "call" in active_ so that it can be aborted cleanly.
RegisterCall(call);
+ // RendezvousMgr already aborted, shouldn't send RPC call any more
+ if (!call->status().ok()) {
+ done(call->status(), Args(), Args(), Tensor(), false);
+ session()->worker_cache->ReleaseWorker(src_worker, rwi);
+ delete call;
+ return;
+ }
+
// Start "call".
Ref();
call->Start([this, call, src_worker, rwi, done]() {
diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc
index 478b716..108da04 100644
--- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc
+++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc
@@ -115,7 +115,7 @@
*context->device()->tensorflow_cpu_worker_threads();
Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
kCostPerChannel,
- [channel_count, &input_data, &output_data, &tranformation_matrix](
+ [&input_data, &output_data, &tranformation_matrix](
int64 start_channel, int64 end_channel) {
// Applying projection matrix to input RGB vectors.
const float* p = input_data.data() + start_channel * kChannelSize;
diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py
index 3327a9f..9e19884 100644
--- a/tensorflow/contrib/keras/api/keras/layers/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py
@@ -20,7 +20,7 @@
# Generic layers.
# pylint: disable=g-bad-import-order
-from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
diff --git a/tensorflow/contrib/keras/api/keras/utils/__init__.py b/tensorflow/contrib/keras/api/keras/utils/__init__.py
index 47cd01b..3b9fa1b 100644
--- a/tensorflow/contrib/keras/api/keras/utils/__init__.py
+++ b/tensorflow/contrib/keras/api/keras/utils/__init__.py
@@ -30,6 +30,7 @@
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.keras.utils.io_utils import HDF5Matrix
from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model
+from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.keras.utils.np_utils import normalize
from tensorflow.python.keras.utils.np_utils import to_categorical
from tensorflow.python.keras.utils.vis_utils import plot_model
diff --git a/tensorflow/contrib/kernel_methods/python/kernel_estimators.py b/tensorflow/contrib/kernel_methods/python/kernel_estimators.py
index de75302..1626e55 100644
--- a/tensorflow/contrib/kernel_methods/python/kernel_estimators.py
+++ b/tensorflow/contrib/kernel_methods/python/kernel_estimators.py
@@ -90,7 +90,7 @@
mapped_column_name = column_name + "_MAPPED"
# Construct new feature columns based on provided kernel_mappers.
column_kernel_mappers = kernel_mappers_dict[feature_column]
- new_dim = sum([mapper.output_dim for mapper in column_kernel_mappers])
+ new_dim = sum(mapper.output_dim for mapper in column_kernel_mappers)
mapped_columns.add(
layers.feature_column.real_valued_column(mapped_column_name, new_dim))
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
index 8015a57..295c721 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
@@ -21,6 +21,7 @@
import itertools
import math
+import sys
import numpy as np
@@ -36,6 +37,7 @@
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@@ -48,11 +50,13 @@
assert num_shards > 0
assert num_shards <= vocab_size
- embedding_weights = partitioned_variables.create_partitioned_variables(
+ initializer = init_ops.truncated_normal_initializer(
+ mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)
+ embedding_weights = list(variable_scope.get_variable(
+ "embedding_weights",
shape=[vocab_size, embed_dim],
- slicing=[num_shards, 1],
- initializer=init_ops.truncated_normal_initializer(
- mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32))
+ partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
+ initializer=initializer))
for w in embedding_weights:
w.initializer.run()
embedding_weights = [w.eval() for w in embedding_weights]
@@ -256,6 +260,13 @@
embedding_weights, sparse_ids, sparse_weights)
+# pylint: disable=invalid-name
+def local_variable_scope():
+ """Create a variable scope named like the caller function."""
+ return variable_scope.variable_scope(sys._getframe(1).f_code.co_name)
+# pylint: enable=invalid-name
+
+
class ScatteredEmbeddingLookupTest(test.TestCase):
def setUp(self):
@@ -266,17 +277,18 @@
assert num_shards > 0
assert num_shards <= size
- embedding_weights = partitioned_variables.create_partitioned_variables(
+ embedding_weights = list(variable_scope.get_variable(
+ "embedding_weights",
shape=[size],
- slicing=[num_shards],
+ partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
initializer=init_ops.truncated_normal_initializer(
- mean=0.0, stddev=1.0, dtype=dtypes.float32))
+ mean=0.0, stddev=1.0, dtype=dtypes.float32)))
for w in embedding_weights:
w.initializer.run()
return embedding_weights
def test_scattered_embedding_consistency(self):
- with self.cached_session():
+ with self.cached_session(), local_variable_scope():
embedding_weights = self._random_weights()
values = constant_op.constant(["foo", "foo"])
@@ -288,7 +300,7 @@
embedding_lookup_result[1])
def test_scattered_embedding_multiple_partition(self):
- with self.cached_session():
+ with self.cached_session(), local_variable_scope():
embedding_weights = self._random_weights(num_shards=7)
values = constant_op.constant([4, 4, 5])
@@ -304,7 +316,7 @@
self.assertGreater(embedding_diff, 0)
def test_scattered_embedding_coverage(self):
- with self.cached_session():
+ with self.cached_session(), local_variable_scope():
size = 8
embedding_weights = self._random_weights(size=size, num_shards=3)
values = constant_op.constant(["foo"])
@@ -316,7 +328,7 @@
self.assertEqual(len(np.unique(embedding_lookup_result[0])), size)
def test_scattered_embedding_multi_dimension(self):
- with self.cached_session():
+ with self.cached_session(), local_variable_scope():
embedding_weights = self._random_weights()
values = constant_op.constant([["foo", "bar", "bar"],
["bar", "bar", "foo"]])
@@ -329,7 +341,7 @@
embedding_lookup_result[1][2])
def test_scattered_embedding_lookup_sparse(self):
- with self.cached_session():
+ with self.cached_session(), local_variable_scope():
embedding_weights = self._random_weights(num_shards=3)
sparse_tensor = sparse_tensor_lib.SparseTensor(
values=["foo", "bar", "foo", "bar"],
@@ -358,7 +370,7 @@
embeds = np.random.randn(n_embed, d_embed)
idx = np.random.randint(0, n_embed, idx_shape)
- with self.cached_session():
+ with self.cached_session(), local_variable_scope():
embedded_np = embeds[idx]
embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -370,7 +382,7 @@
idx = np.random.randint(0, 5, 10)
idx2d = np.random.randint(0, 5, (10, 2))
- with self.cached_session():
+ with self.cached_session(), local_variable_scope():
embedded_np = embeds[idx]
embedded_np2d = embeds[idx2d]
embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -398,17 +410,18 @@
assert num_shards > 0
assert num_shards <= size
- embedding_weights = partitioned_variables.create_partitioned_variables(
+ embedding_weights = list(variable_scope.get_variable(
+ "embedding_weights",
shape=[size],
- slicing=[num_shards],
+ partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
initializer=init_ops.truncated_normal_initializer(
- mean=0.0, stddev=1.0, dtype=dtypes.float32))
+ mean=0.0, stddev=1.0, dtype=dtypes.float32)))
for w in embedding_weights:
w.initializer.run()
return embedding_weights
def test_hashed_embedding_consistency(self):
- with self.cached_session():
+ with self.cached_session(), local_variable_scope():
embedding_weights = self._random_weights()
values = constant_op.constant(["foo", "foo"])
# The first three sampled_candidates are equal, so the first three
@@ -429,7 +442,7 @@
embedding_lookup_result[1][3])
def test_hashed_embedding_multi_dimension(self):
- with self.cached_session():
+ with self.cached_session(), local_variable_scope():
embedding_weights = self._random_weights()
values = constant_op.constant([["foo", "bar", "bar"],
["bar", "bar", "foo"]])
diff --git a/tensorflow/contrib/layers/python/layers/encoders.py b/tensorflow/contrib/layers/python/layers/encoders.py
index f421122..3671633 100644
--- a/tensorflow/contrib/layers/python/layers/encoders.py
+++ b/tensorflow/contrib/layers/python/layers/encoders.py
@@ -84,8 +84,7 @@
if isinstance(ids, sparse_tensor.SparseTensor):
raise TypeError('ids are expected to be dense Tensor, got: %s', ids)
return math_ops.reduce_mean(
- embedding_ops.embedding_lookup(embeddings, ids),
- reduction_indices=1)
+ embedding_ops.embedding_lookup(embeddings, ids), axis=1)
def embed_sequence(ids,
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index 222404b..00d819e 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -1015,8 +1015,7 @@
dense_id_tensor, depth=self.length, on_value=1.0, off_value=0.0)
# Reduce to get a multi-hot per example.
- return math_ops.reduce_sum(
- one_hot_id_tensor, reduction_indices=[output_rank - 1])
+ return math_ops.reduce_sum(one_hot_id_tensor, axis=[output_rank - 1])
@property
def _variable_shape(self):
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index ac9561c..403b522 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -35,6 +35,7 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras.engine import input_spec
from tensorflow.python.layers import base
from tensorflow.python.layers import convolutional as convolutional_layers
from tensorflow.python.layers import core as core_layers
@@ -1958,7 +1959,7 @@
self._reparam_offset = reparam_offset
self.data_format = data_format
self._channel_axis() # trigger ValueError early
- self.input_spec = base.InputSpec(min_ndim=3, max_ndim=5)
+ self.input_spec = input_spec.InputSpec(min_ndim=3, max_ndim=5)
def _channel_axis(self):
try:
@@ -2015,7 +2016,7 @@
raise ValueError('The channel dimension of the inputs to `GDN` '
'must be defined.')
self._input_rank = input_shape.ndims
- self.input_spec = base.InputSpec(
+ self.input_spec = input_spec.InputSpec(
ndim=input_shape.ndims, axes={
channel_axis: num_channels
})
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 8ead633..0a4d2c6 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -3811,7 +3811,7 @@
image = random_ops.random_uniform((height, width, 3))
output = _layers.unit_norm(image, dim=dim, epsilon=1e-6)
norms = math_ops.sqrt(
- math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim))
+ math_ops.reduce_sum(math_ops.square(output), axis=dim))
shape = [height, width, 3]
del shape[dim]
@@ -3847,7 +3847,7 @@
image = array_ops.placeholder(dtypes.float32, (None, None, 3))
output = _layers.unit_norm(image, dim=dim, epsilon=1e-6)
norms = math_ops.sqrt(
- math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim))
+ math_ops.reduce_sum(math_ops.square(output), axis=dim))
with self.cached_session():
actual = norms.eval({image: placeholder_value})
diff --git a/tensorflow/contrib/layers/python/layers/regularizers_test.py b/tensorflow/contrib/layers/python/layers/regularizers_test.py
index 51faba3..5cb00b7 100644
--- a/tensorflow/contrib/layers/python/layers/regularizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/regularizers_test.py
@@ -141,7 +141,7 @@
dummy_regularizer = lambda x: math_ops.reduce_sum(2 * x)
array_weights_list = [[1.5], [2, 3, 4.2], [10, 42, 666.6]]
tensor_weights_list = [constant_op.constant(x) for x in array_weights_list]
- expected = sum([2 * x for l in array_weights_list for x in l])
+ expected = sum(2 * x for l in array_weights_list for x in l)
with self.cached_session():
result = regularizers.apply_regularization(dummy_regularizer,
tensor_weights_list)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
index 18ca421..10fbd60 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
@@ -150,10 +150,10 @@
"input_from_feature_columns",
values=tuple(six.itervalues(features)),
partitioner=input_layer_partitioner) as input_layer_scope:
- if all([
+ if all(
isinstance(fc, feature_column._FeatureColumn) # pylint: disable=protected-access
for fc in feature_columns
- ]):
+ ):
net = layers.input_from_feature_columns(
columns_to_tensors=features,
feature_columns=feature_columns,
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index 7a3cc8bd..2ade6b7 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -236,10 +236,10 @@
"input_from_feature_columns",
values=tuple(six.itervalues(features)),
partitioner=input_layer_partitioner) as dnn_input_scope:
- if all([
+ if all(
isinstance(fc, feature_column_lib._FeatureColumn) # pylint: disable=protected-access
for fc in dnn_feature_columns
- ]):
+ ):
net = layers.input_from_feature_columns(
columns_to_tensors=features,
feature_columns=dnn_feature_columns,
@@ -292,8 +292,8 @@
linear_parent_scope,
values=tuple(six.itervalues(features)),
partitioner=linear_partitioner) as scope:
- if all([isinstance(fc, feature_column_lib._FeatureColumn) # pylint: disable=protected-access
- for fc in linear_feature_columns]):
+ if all(isinstance(fc, feature_column_lib._FeatureColumn) # pylint: disable=protected-access
+ for fc in linear_feature_columns):
if joint_linear_weights:
linear_logits, _, _ = layers.joint_weighted_sum_from_feature_columns(
columns_to_tensors=features,
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
index 1d8a592..28c4964 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
@@ -668,7 +668,7 @@
sequences = centers + noise
inputs = array_ops.expand_dims(sequences, 2)
- labels = math_ops.reduce_mean(sequences, reduction_indices=[1])
+ labels = math_ops.reduce_mean(sequences, axis=[1])
return {'inputs': inputs}, labels
return input_fn
@@ -722,8 +722,8 @@
inputs = array_ops.expand_dims(math_ops.to_float(random_sequence), 2)
labels = math_ops.to_int32(
array_ops.squeeze(
- math_ops.reduce_sum(
- inputs, reduction_indices=[1]) > (sequence_length / 2.0)))
+ math_ops.reduce_sum(inputs, axis=[1]) > (
+ sequence_length / 2.0)))
return {'inputs': inputs}, labels
return input_fn
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 8bc869d..9132b22 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -1066,11 +1066,11 @@
chief_hooks = []
if (self._config.save_checkpoints_secs or
self._config.save_checkpoints_steps):
- saver_hook_exists = any([
+ saver_hook_exists = any(
isinstance(h, basic_session_run_hooks.CheckpointSaverHook)
for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks +
model_fn_ops.training_chief_hooks)
- ])
+ )
if not saver_hook_exists:
chief_hooks = [
basic_session_run_hooks.CheckpointSaverHook(
@@ -1493,7 +1493,7 @@
# pylint: disable=protected-access
class SKCompat(sklearn.BaseEstimator):
"""Scikit learn wrapper for TensorFlow Learn Estimator.
-
+
THIS CLASS IS DEPRECATED. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for general migration instructions.
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index 439b17e..9ee8d80 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -155,8 +155,8 @@
parent_scope,
values=tuple(six.itervalues(features)),
partitioner=partitioner) as scope:
- if all([isinstance(fc, feature_column._FeatureColumn) # pylint: disable=protected-access
- for fc in feature_columns]):
+ if all(isinstance(fc, feature_column._FeatureColumn) # pylint: disable=protected-access
+ for fc in feature_columns):
if joint_weights:
layer_fn = layers.joint_weighted_sum_from_feature_columns
else:
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index d8ac416..709a042 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -59,9 +59,8 @@
"""
# First, compute the sum of the losses over all elements:
start_index = max(0, weights.get_shape().ndims)
- reduction_indices = list(range(start_index, losses.get_shape().ndims))
- reduced_losses = math_ops.reduce_sum(
- losses, reduction_indices=reduction_indices)
+ axis = list(range(start_index, losses.get_shape().ndims))
+ reduced_losses = math_ops.reduce_sum(losses, axis=axis)
reduced_losses = math_ops.multiply(reduced_losses, weights)
return math_ops.reduce_sum(reduced_losses)
@@ -158,10 +157,9 @@
# First, count the number of nonzero weights:
if weights.get_shape().ndims >= 1:
- reduction_indices = list(range(1, weights.get_shape().ndims))
+ axis = list(range(1, weights.get_shape().ndims))
num_nonzero_per_batch = math_ops.reduce_sum(
- math_ops.to_float(math_ops.not_equal(weights, 0)),
- reduction_indices=reduction_indices)
+ math_ops.to_float(math_ops.not_equal(weights, 0)), axis=axis)
# Next, determine the number of elements that weights would broadcast to:
broadcast_dims = array_ops.slice(
@@ -577,16 +575,16 @@
if weights.get_shape().ndims is None:
raise ValueError("weights.get_shape().ndims cannot be None")
- reduction_indices = list(range(1, diffs.get_shape().ndims))
+ axis = list(range(1, diffs.get_shape().ndims))
sum_squares_diff_per_batch = math_ops.reduce_sum(
- math_ops.square(diffs), reduction_indices=reduction_indices)
+ math_ops.square(diffs), axis=axis)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
term1 = 2.0 * math_ops.div_no_nan(
sum_squares_diff_per_batch, num_present_per_batch, name="value")
- sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
+ sum_diff = math_ops.reduce_sum(diffs, axis=axis)
term2 = 2.0 * math_ops.div_no_nan(
math_ops.square(sum_diff),
math_ops.square(num_present_per_batch),
@@ -645,7 +643,7 @@
radial_diffs = math_ops.multiply(predictions, labels)
losses = 1 - math_ops.reduce_sum(
- radial_diffs, reduction_indices=[
+ radial_diffs, axis=[
axis,
])
return compute_weighted_loss(losses, weights, scope=scope)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 09fe65b..7b432f8 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -3416,7 +3416,7 @@
predictions.get_shape().assert_is_compatible_with(labels.get_shape())
radial_diffs = math_ops.multiply(predictions, labels)
radial_diffs = math_ops.reduce_sum(
- radial_diffs, reduction_indices=[
+ radial_diffs, axis=[
dim,
], keepdims=True)
mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None,
diff --git a/tensorflow/contrib/model_pruning/python/layers/core_layers.py b/tensorflow/contrib/model_pruning/python/layers/core_layers.py
index f0ce6fe..1fa5c8c 100644
--- a/tensorflow/contrib/model_pruning/python/layers/core_layers.py
+++ b/tensorflow/contrib/model_pruning/python/layers/core_layers.py
@@ -21,6 +21,7 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras.engine import input_spec
from tensorflow.python.layers import base
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
@@ -119,7 +120,7 @@
self.bias_initializer = bias_initializer
self.kernel_regularizer = kernel_regularizer
self.bias_regularizer = bias_regularizer
- self.input_spec = base.InputSpec(ndim=self.rank + 2)
+ self.input_spec = input_spec.InputSpec(ndim=self.rank + 2)
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
@@ -171,7 +172,7 @@
dtype=self.dtype)
else:
self.bias = None
- self.input_spec = base.InputSpec(
+ self.input_spec = input_spec.InputSpec(
ndim=self.rank + 2, axes={channel_axis: input_dim})
self.built = True
@@ -393,14 +394,14 @@
self.bias_initializer = bias_initializer
self.kernel_regularizer = kernel_regularizer
self.bias_regularizer = bias_regularizer
- self.input_spec = base.InputSpec(min_ndim=2)
+ self.input_spec = input_spec.InputSpec(min_ndim=2)
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if tensor_shape.dimension_value(input_shape[-1]) is None:
raise ValueError('The last dimension of the inputs to `Dense` '
'should be defined. Found `None`.')
- self.input_spec = base.InputSpec(
+ self.input_spec = input_spec.InputSpec(
min_ndim=2, axes={-1: tensor_shape.dimension_value(input_shape[-1])})
self.kernel = self.add_variable(
diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer.py b/tensorflow/contrib/opt/python/training/lars_optimizer.py
index a8dafd9..bc18177 100644
--- a/tensorflow/contrib/opt/python/training/lars_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/lars_optimizer.py
@@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
@@ -162,3 +163,14 @@
math_ops.cast(self._momentum_tensor, grad.dtype),
use_locking=self._use_locking,
use_nesterov=self._use_nesterov)
+
+ def _prepare(self):
+ learning_rate = self._learning_rate
+ if callable(learning_rate):
+ learning_rate = learning_rate()
+ self._learning_rate_tensor = ops.convert_to_tensor(
+ learning_rate, name="learning_rate")
+ momentum = self._momentum
+ if callable(momentum):
+ momentum = momentum()
+ self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")
diff --git a/tensorflow/contrib/optimizer_v2/BUILD b/tensorflow/contrib/optimizer_v2/BUILD
index 835fb4a..6e40140 100644
--- a/tensorflow/contrib/optimizer_v2/BUILD
+++ b/tensorflow/contrib/optimizer_v2/BUILD
@@ -48,7 +48,6 @@
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:control_flow_ops",
- "//tensorflow/python:distribute",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:resource_variable_ops",
@@ -56,6 +55,7 @@
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
+ "//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:reduce_util",
],
)
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index d6dedc2..73a556f 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -24,6 +24,7 @@
import six
+from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -35,7 +36,6 @@
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import distribution_strategy_context as distribute_ctx
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.training import slot_creator
@@ -447,7 +447,7 @@
if v is None:
if colocate_with is None:
colocate_with = self._non_slot_devices
- with self._distribution.colocate_vars_with(colocate_with):
+ with self._distribution.extended.colocate_vars_with(colocate_with):
# TODO(josh11b): Use get_variable() except for the legacy Adam use case.
v = variable_scope.variable(initial_value, name=name, trainable=False)
self._non_slot_dict[name] = v
@@ -658,7 +658,6 @@
var_list=None,
gate_gradients=GATE_OP,
aggregation_method=None,
- colocate_gradients_with_ops=False,
name=None,
grad_loss=None,
stop_gradients=None,
@@ -681,8 +680,6 @@
`GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
aggregation_method: Specifies the method used to combine gradient terms.
Valid values are defined in the class `AggregationMethod`.
- colocate_gradients_with_ops: If True, try colocating gradients with the
- corresponding op.
name: Optional name for the returned operation.
grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
stop_gradients: Optional. A Tensor or list of tensors not to differentiate
@@ -705,8 +702,8 @@
Minimization (and gradient computation) is done with respect to the
elements of `var_list` if not None, else with respect to any trainable
variables created during the execution of the `loss` function.
- `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
- `grad_loss` are ignored when eager execution is enabled.
+ `gate_gradients`, `aggregation_method`, and `grad_loss` are ignored when
+ eager execution is enabled.
@end_compatibility
"""
grads_and_vars = self.compute_gradients(
@@ -714,7 +711,6 @@
var_list=var_list,
gate_gradients=gate_gradients,
aggregation_method=aggregation_method,
- colocate_gradients_with_ops=colocate_gradients_with_ops,
grad_loss=grad_loss,
stop_gradients=stop_gradients,
scale_loss_by_num_replicas=scale_loss_by_num_replicas)
@@ -734,7 +730,6 @@
var_list=None,
gate_gradients=GATE_OP,
aggregation_method=None,
- colocate_gradients_with_ops=False,
grad_loss=None,
stop_gradients=None,
scale_loss_by_num_replicas=None):
@@ -757,8 +752,6 @@
`GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
aggregation_method: Specifies the method used to combine gradient terms.
Valid values are defined in the class `AggregationMethod`.
- colocate_gradients_with_ops: If True, try colocating gradients with the
- corresponding op.
grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
stop_gradients: Optional. A Tensor or list of tensors not to differentiate
through.
@@ -777,8 +770,8 @@
not callable.
@compatibility(eager)
- When eager execution is enabled, `gate_gradients`, `aggregation_method`,
- and `colocate_gradients_with_ops` are ignored.
+ When eager execution is enabled, `gate_gradients`, and `aggregation_method`
+ are ignored.
@end_compatibility
"""
# TODO(josh11b): Test that we handle weight decay in a reasonable way.
@@ -833,7 +826,6 @@
grad_ys=grad_loss,
gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP),
aggregation_method=aggregation_method,
- colocate_gradients_with_ops=colocate_gradients_with_ops,
stop_gradients=stop_gradients)
if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH:
grads = control_flow_ops.tuple(grads)
@@ -928,7 +920,7 @@
def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
"""`apply_gradients` for use with a `DistributionStrategy`."""
- reduced_grads = distribution.batch_reduce(
+ reduced_grads = distribution.extended.batch_reduce_to(
ds_reduce_util.ReduceOp.SUM, grads_and_vars)
var_list = [v for _, v in grads_and_vars]
grads_and_vars = zip(reduced_grads, var_list)
@@ -945,7 +937,7 @@
with ops.name_scope(name, self._name) as name:
per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list)
# Include the current value of any dynamic hyper parameters in `state`.
- non_slot_devices = distribution.non_slot_devices(var_list)
+ non_slot_devices = distribution.extended.non_slot_devices(var_list)
state = per_graph_state._copy_with_dynamic_hyper( # pylint: disable=protected-access
self._hyper, distribution, non_slot_devices)
@@ -990,7 +982,8 @@
# Use the processors to update the variables.
update_ops = []
for grad, var in grads_and_vars:
- update_ops.extend(distribution.update(var, update, grad, grouped=False))
+ update_ops.extend(distribution.extended.update(
+ var, update, args=(grad,), group=False))
# Give the child class a chance to do something after applying
# gradients
@@ -1002,8 +995,8 @@
update_ops = control_flow_ops.group(update_ops)
with ops.control_dependencies([update_ops]):
- finish_updates = distribution.update_non_slot(
- non_slot_devices, finish, grouped=False)
+ finish_updates = distribution.extended.update_non_slot(
+ non_slot_devices, finish, group=False)
# We said grouped=False, which means finish_updates is always a list.
# It will be [None] when finish() returns None.
if finish_updates == [None]:
@@ -1018,8 +1011,8 @@
def update_global_step(global_step, name):
return global_step.assign_add(1, read_value=False, name=name)
- apply_updates = distribution.update(global_step, update_global_step,
- name)
+ apply_updates = distribution.extended.update(
+ global_step, update_global_step, args=(name,))
# Add the training op to the TRAIN_OP graph collection in graph mode.
if not eager_execution:
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py
index 6f65934..8619708 100644
--- a/tensorflow/contrib/quantize/python/quant_ops.py
+++ b/tensorflow/contrib/quantize/python/quant_ops.py
@@ -138,7 +138,7 @@
if per_channel:
if input_dim >= 2:
batch_min = math_ops.reduce_min(
- inputs, reduction_indices=reduce_dims, name='BatchMin')
+ inputs, axis=reduce_dims, name='BatchMin')
else:
batch_min = inputs
else:
@@ -147,7 +147,7 @@
if per_channel:
if input_dim >= 2:
batch_max = math_ops.reduce_max(
- inputs, reduction_indices=reduce_dims, name='BatchMax')
+ inputs, axis=reduce_dims, name='BatchMax')
else:
batch_max = inputs
else:
@@ -263,7 +263,7 @@
if per_channel:
if input_dim >= 2:
batch_min = math_ops.reduce_min(
- inputs, reduction_indices=reduce_dims, name='BatchMin')
+ inputs, axis=reduce_dims, name='BatchMin')
else:
batch_min = inputs
else:
@@ -272,7 +272,7 @@
if per_channel:
if input_dim >= 2:
batch_max = math_ops.reduce_max(
- inputs, reduction_indices=reduce_dims, name='BatchMax')
+ inputs, axis=reduce_dims, name='BatchMax')
else:
batch_max = inputs
else:
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 338923f..21d1b12 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -160,7 +160,7 @@
# shouldn't quantize it, since the activation will be Fused into the
# Add at inference time.
consumers = input_to_ops_map.ConsumerOperations(layer_match.bypass_op)
- if any([consumer.type in _ACTIVATION_TYPES for consumer in consumers]):
+ if any(consumer.type in _ACTIVATION_TYPES for consumer in consumers):
logging.info('Skipping %s, because its followed by an activation.',
layer_match.bypass_op.name)
else:
@@ -195,7 +195,7 @@
# Add at inference time.
consumers = input_to_ops_map.ConsumerOperations(
layer_match.post_activation_bypass_op)
- if any([consumer.type in _RELU_TYPES for consumer in consumers]):
+ if any(consumer.type in _RELU_TYPES for consumer in consumers):
logging.info('Skipping %s, because its followed by an activation.',
layer_match.post_activation_bypass_op.name)
else:
diff --git a/tensorflow/contrib/resampler/BUILD b/tensorflow/contrib/resampler/BUILD
index 38fcca0..bbf1099 100644
--- a/tensorflow/contrib/resampler/BUILD
+++ b/tensorflow/contrib/resampler/BUILD
@@ -13,6 +13,7 @@
)
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
tf_custom_op_py_library(
name = "resampler_py",
@@ -50,10 +51,14 @@
prefix = "resampler_ops",
deps = [
":resampler_ops_op_lib",
- "//tensorflow/compiler/tf2xla/kernels:resampler_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
- ],
+ ] + select({
+ "//tensorflow:with_xla_support": [
+ "//tensorflow/compiler/tf2xla/kernels:resampler_ops",
+ ],
+ "//conditions:default": [],
+ }),
alwayslink = 1,
)
@@ -94,3 +99,26 @@
"//tensorflow/python:array_ops",
],
)
+
+tf_xla_py_test(
+ name = "resampler_ops_xla_test",
+ size = "small",
+ srcs = ["xla/resampler_ops_xla_test.py"],
+ disabled_backends = [
+ # TODO(b/74459949) Support BatchDot in CPU backend.
+ "cpu",
+ "cpu_ondemand",
+ ],
+ # TODO(b/112295522): the OSS build will not likely work in the short to medium term, currently it is blocked by the fact that bazel does not allow py_library to depend on cc_library: https://github.com/bazelbuild/bazel/issues/701 which may not be resolvable.
+ tags = ["no_oss"],
+ deps = [
+ "//tensorflow/compiler/tests:xla_test",
+ "//tensorflow/compiler/tf2xla/kernels:resampler_ops",
+ "//tensorflow/contrib/resampler:resampler_ops",
+ "//tensorflow/contrib/resampler:resampler_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:platform_test",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/compiler/tests/resampler_ops_test.py b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py
similarity index 76%
rename from tensorflow/compiler/tests/resampler_ops_test.py
rename to tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py
index f87ac33..d8ca0ea 100644
--- a/tensorflow/compiler/tests/resampler_ops_test.py
+++ b/tensorflow/contrib/resampler/xla/resampler_ops_xla_test.py
@@ -63,8 +63,8 @@
def testSimple(self):
for dtype in self.float_types:
input_shape = [1, 2, 2, 1]
- input_rgb_data = [0, 5, 13, 54]
- input_np = np.array(input_rgb_data, dtype=dtype).reshape(input_shape)
+ input_data = [0, 5, 13, 54]
+ input_np = np.array(input_data, dtype=dtype).reshape(input_shape)
warp_shape = [1, 2]
warp_data = [0.7, 0.6]
@@ -151,6 +151,55 @@
expected_grad_data,
expected_grad_warp)
+ def testOutOfBoundWarps(self):
+ # (x, y) are both less than 0.
+ for dtype in self.float_types:
+ input_shape = [1, 2, 2, 1]
+ input_data = [10, 5, 13, 54]
+ input_np = np.array(input_data, dtype=dtype).reshape(input_shape)
+
+ warp_shape = [1, 2, 2]
+ warp_data = [-1, -1, 0.7, 0.6]
+ warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape)
+ expected = [[[0.0], [27.62]]]
+ self._assertForwardOpMatchesExpected(input_np, warp_np, expected)
+
+ # One of (x, y) is less than 0.
+ for dtype in self.float_types:
+ input_shape = [1, 2, 2, 1]
+ input_data = [10, 5, 13, 54]
+ input_np = np.array(input_data, dtype=dtype).reshape(input_shape)
+
+ warp_shape = [1, 2, 2]
+ warp_data = [-1, 0.1, 0.7, 0.6]
+ warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape)
+ expected = [[[0.0], [27.62]]]
+ self._assertForwardOpMatchesExpected(input_np, warp_np, expected)
+
+ # Both of (x, y) are greater than image size.
+ for dtype in self.float_types:
+ input_shape = [1, 2, 2, 1]
+ input_data = [10, 5, 13, 54]
+ input_np = np.array(input_data, dtype=dtype).reshape(input_shape)
+
+ warp_shape = [1, 2, 2]
+ warp_data = [-0.1, 0.1, 1.2, 2.1]
+ warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape)
+ expected = [[[0.0], [0.0]]]
+ self._assertForwardOpMatchesExpected(input_np, warp_np, expected)
+
+ # One of (x, y) is greater than image size.
+ for dtype in self.float_types:
+ input_shape = [1, 2, 2, 1]
+ input_data = [10, 5, 13, 54]
+ input_np = np.array(input_data, dtype=dtype).reshape(input_shape)
+
+ warp_shape = [1, 2, 2]
+ warp_data = [0.1, -0.1, 1.2, 0.1]
+ warp_np = np.array(warp_data, dtype=dtype).reshape(warp_shape)
+ expected = [[[0.0], [0.0]]]
+ self._assertForwardOpMatchesExpected(input_np, warp_np, expected)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 245fa68..7d57b04 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -906,7 +906,7 @@
def testDropoutWrapperKeepNoOutput(self):
keep_all = variable_scope.get_variable("all", initializer=1.0)
- keep_none = variable_scope.get_variable("none", initializer=1e-10)
+ keep_none = variable_scope.get_variable("none", initializer=1e-6)
res = self._testDropoutWrapper(
input_keep_prob=keep_all,
output_keep_prob=keep_none,
@@ -922,7 +922,7 @@
def testDropoutWrapperKeepNoStateExceptLSTMCellMemory(self):
keep_all = variable_scope.get_variable("all", initializer=1.0)
- keep_none = variable_scope.get_variable("none", initializer=1e-10)
+ keep_none = variable_scope.get_variable("none", initializer=1e-6)
# Even though we dropout state, by default DropoutWrapper never
# drops out the memory ("c") term of an LSTMStateTuple.
res = self._testDropoutWrapper(
@@ -943,7 +943,7 @@
def testDropoutWrapperKeepNoInput(self):
keep_all = variable_scope.get_variable("all", initializer=1.0)
- keep_none = variable_scope.get_variable("none", initializer=1e-10)
+ keep_none = variable_scope.get_variable("none", initializer=1e-6)
true_full_output = np.array(
[[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
dtype=np.float32)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index 5cba54d..ef372b9 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -227,7 +227,7 @@
def testDropout(self):
cell = Plus1RNNCell()
full_dropout_cell = rnn_cell.DropoutWrapper(
- cell, input_keep_prob=1e-12, seed=0)
+ cell, input_keep_prob=1e-6, seed=0)
(name, dep), = full_dropout_cell._checkpoint_dependencies
self.assertIs(dep, cell)
self.assertEqual("cell", name)
diff --git a/tensorflow/contrib/rnn/python/ops/gru_ops.py b/tensorflow/contrib/rnn/python/ops/gru_ops.py
index b30ca78..251a933 100644
--- a/tensorflow/contrib/rnn/python/ops/gru_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/gru_ops.py
@@ -21,7 +21,7 @@
from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.layers import base as base_layer
+from tensorflow.python.keras.engine import input_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -165,7 +165,7 @@
num_units = cell_size
self._cell_size = num_units
# Inputs must be 2-dimensional.
- self.input_spec = base_layer.InputSpec(ndim=2)
+ self.input_spec = input_spec.InputSpec(ndim=2)
@property
def state_size(self):
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index 4db431f..b043026 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -25,6 +25,7 @@
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.keras.engine import input_spec
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@@ -385,7 +386,7 @@
"scope": "lstm_cell"
}
# Inputs must be 2-dimensional.
- self.input_spec = base_layer.InputSpec(ndim=2)
+ self.input_spec = input_spec.InputSpec(ndim=2)
@property
def state_size(self):
@@ -628,7 +629,7 @@
self._use_peephole = use_peephole
# Inputs must be 3-dimensional.
- self.input_spec = base_layer.InputSpec(ndim=3)
+ self.input_spec = input_spec.InputSpec(ndim=3)
@property
def num_units(self):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index e159dc9..8a1c09f 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -30,7 +30,7 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import initializers
-from tensorflow.python.layers import base as base_layer
+from tensorflow.python.keras.engine import input_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gen_array_ops
@@ -2752,7 +2752,7 @@
self._activation = activation or math_ops.tanh
# Restrict inputs to be 2-dimensional matrices
- self.input_spec = base_layer.InputSpec(ndim=2)
+ self.input_spec = input_spec.InputSpec(ndim=2)
@property
def state_size(self):
@@ -3089,7 +3089,7 @@
super(IndRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
# Inputs must be 2-dimensional.
- self.input_spec = base_layer.InputSpec(ndim=2)
+ self.input_spec = input_spec.InputSpec(ndim=2)
self._num_units = num_units
self._activation = activation or math_ops.tanh
@@ -3183,7 +3183,7 @@
super(IndyGRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
# Inputs must be 2-dimensional.
- self.input_spec = base_layer.InputSpec(ndim=2)
+ self.input_spec = input_spec.InputSpec(ndim=2)
self._num_units = num_units
self._activation = activation or math_ops.tanh
@@ -3323,7 +3323,7 @@
super(IndyLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
# Inputs must be 2-dimensional.
- self.input_spec = base_layer.InputSpec(ndim=2)
+ self.input_spec = input_spec.InputSpec(ndim=2)
self._num_units = num_units
self._forget_bias = forget_bias
@@ -3444,7 +3444,7 @@
super(MinimalRNNCell, self).__init__(name=name, dtype=dtype, **kwargs)
# Inputs must be 2-dimensional.
- self.input_spec = base_layer.InputSpec(ndim=2)
+ self.input_spec = input_spec.InputSpec(ndim=2)
self.units = units
self.activation = activations.get(activation)
@@ -3558,7 +3558,7 @@
super(CFNCell, self).__init__(name=name, dtype=dtype, **kwargs)
# Inputs must be 2-dimensional.
- self.input_spec = base_layer.InputSpec(ndim=2)
+ self.input_spec = input_spec.InputSpec(ndim=2)
self.units = units
self.activation = activations.get(activation)
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 8668c67..922f21b 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -154,8 +154,8 @@
if attention_layer_sizes is not None:
# Compute sum of attention_layer_sizes. Use encoder_output_depth if None.
- attention_depth = sum([attention_layer_size or encoder_output_depth
- for attention_layer_size in attention_layer_sizes])
+ attention_depth = sum(attention_layer_size or encoder_output_depth
+ for attention_layer_size in attention_layer_sizes)
elif attention_layers is not None:
# Compute sum of attention_layers output depth.
attention_depth = sum(
diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py
index 605625c..42898e7 100644
--- a/tensorflow/contrib/summary/summary.py
+++ b/tensorflow/contrib/summary/summary.py
@@ -79,7 +79,6 @@
from tensorflow.python.ops.summary_ops_v2 import import_event
from tensorflow.python.ops.summary_ops_v2 import initialize
from tensorflow.python.ops.summary_ops_v2 import never_record_summaries
-from tensorflow.python.ops.summary_ops_v2 import record_summaries
from tensorflow.python.ops.summary_ops_v2 import record_summaries_every_n_global_steps
from tensorflow.python.ops.summary_ops_v2 import scalar
from tensorflow.python.ops.summary_ops_v2 import should_record_summaries
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 20bcd24..784acce 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -29,6 +29,10 @@
"if_tensorrt",
)
+exports_files(glob([
+ "test/testdata/*",
+]))
+
tf_cuda_cc_test(
name = "tensorrt_test_cc",
size = "small",
@@ -491,6 +495,7 @@
"test/memory_alignment_test.py",
"test/multi_connection_neighbor_engine_test.py",
"test/neighboring_engine_test.py",
+ "test/quantization_test.py",
"test/rank_two_test.py",
"test/reshape_transpose_test.py",
"test/vgg_block_nchw_test.py",
@@ -527,6 +532,30 @@
],
)
+cuda_py_test(
+ name = "quantization_mnist_test",
+ srcs = ["test/quantization_mnist_test.py"],
+ additional_deps = [
+ ":tf_trt_integration_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/keras:keras",
+ "//tensorflow/python/estimator:estimator",
+ ],
+ data = [
+ "test/testdata/checkpoint",
+ "test/testdata/model.ckpt-46900.data-00000-of-00001",
+ "test/testdata/model.ckpt-46900.index",
+ ],
+ tags = [
+ "no_cuda_on_cpu_tap",
+ "no_pip",
+ "no_tap", # It is not able to download the mnist data.
+ "no_windows",
+ "nomac",
+ ],
+)
+
cc_library(
name = "utils",
srcs = ["convert/utils.cc"],
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index f95ffe4..21f505b 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -82,60 +82,73 @@
}
TrtCandidateSelector::TrtCandidateSelector(
- const grappler::GraphProperties& graph_properties)
- : graph_properties_(graph_properties) {}
+ const grappler::GraphProperties& graph_properties, int precision_mode)
+ : graph_properties_(graph_properties), precision_mode_(precision_mode) {}
Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) {
// TODO(laigd): move this set to TrtNodeValidator where it should belong.
// LINT.IfChange
static const std::set<string> candidate_ops = {
- "Identity",
- "Snapshot",
- "Const",
- "Conv2D",
- "MaxPool",
- "BiasAdd",
- "Relu",
- "Add",
- "Mul",
- "Sub",
- "Rsqrt",
- "Pad",
- "Mean",
- "AvgPool",
- "ConcatV2",
- "DepthwiseConv2dNative",
- "FusedBatchNorm",
- "FusedBatchNormV2",
- "Div",
- "RealDiv",
- "Rsqrt",
- "Reciprocal",
- "Exp",
- "Log",
- "Sqrt",
- "Abs",
- "Neg",
- "Transpose",
- "Reshape",
- "MatMul",
- "BatchMatMul",
- "Softmax",
- "Minimum",
- "Maximum",
- "TopKV2",
- "Sum",
- "Prod",
- "Max",
- "Min",
+ "Identity",
+ "Snapshot",
+ "Const",
+ "Conv2D",
+ "MaxPool",
+ "BiasAdd",
+ "Relu",
+ "Add",
+ "Mul",
+ "Sub",
+ "Rsqrt",
+ "Pad",
+ "Mean",
+ "AvgPool",
+ "ConcatV2",
+ "DepthwiseConv2dNative",
+ "FusedBatchNorm",
+ "FusedBatchNormV2",
+ "Div",
+ "RealDiv",
+ "Rsqrt",
+ "Reciprocal",
+ "Exp",
+ "Log",
+ "Sqrt",
+ "Abs",
+ "Neg",
+ "Transpose",
+ "Reshape",
+ "MatMul",
+ "BatchMatMul",
+ "Softmax",
+ "Minimum",
+ "Maximum",
+ "TopKV2",
+ "Sum",
+ "Prod",
+ "Max",
+ "Min",
+ "Relu6",
};
- // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc)
- const bool is_supported_op_type =
+ bool is_supported_op_type =
(candidate_ops.count(node->type_string()) ||
PluginFactoryTensorRT::GetInstance()->IsPlugin(node->type_string()));
+ static const std::set<string> quantize_ops = {
+ "QuantizeAndDequantizeV2",
+ "QuantizeAndDequantizeV3",
+ "FakeQuantWithMinMaxVars",
+ "FakeQuantWithMinMaxArgs",
+ };
+ // In INT8 mode, we will always apply the quantization ranges provided by
+ // these ops to the relevant tensors. This happens regardless of the value of
+ // use_calibration.
+ if (precision_mode_ == INT8MODE && quantize_ops.count(node->type_string())) {
+ is_supported_op_type = true;
+ }
+ // LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.cc)
if (!is_supported_op_type) {
return errors::Unimplemented("Op type ", node->type_string(),
- " is not supported.");
+ " is not supported");
}
std::vector<const Edge*> input_edges;
@@ -220,7 +233,8 @@
const std::vector<string>& output_names, size_t max_batch_size,
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
int precision_mode, int minimum_segment_size, bool is_dyn_op,
- int max_cached_engines, std::vector<int> cached_engine_batches) {
+ int max_cached_engines, std::vector<int> cached_engine_batches,
+ bool use_calibration) {
// Create GrapplerItem.
tensorflow::grappler::GrapplerItem item;
item.fetch = output_names;
@@ -287,6 +301,7 @@
list->add_i(batch);
}
}
+ parameters["use_calibration"].set_b(use_calibration);
// Run optimizer.
tensorflow::grappler::MetaOptimizer meta_opt(nullptr, config_proto);
@@ -566,27 +581,30 @@
}
}
}
+
+ const bool calibrate_int8 =
+ (info.precision_mode == INT8MODE && info.use_calibration);
+ // Build the engine and get its serialized representation.
string segment_string;
- if (info.engine_type == EngineInfo::EngineType::TRTStatic ||
- info.precision_mode == INT8MODE) {
+ if (info.engine_type == EngineInfo::EngineType::TRTStatic || calibrate_int8) {
// Create static engine for fp32/fp16 mode, and test validity of the engine
- // for int8 mode. We don't want engine to fail at the calibration time.
- // So we are constructing a FP32 engine here to check its validity, and if
- // it is a valid engine then we put the serialized graphdef to the op.
- // Otherwise we skip node creation for this engine.
+ // for int8 calibration mode. We don't want engine to fail at the
+ // calibration time. So we are constructing a FP32 engine here to check its
+ // validity, and if it is a valid engine then we put the serialized graphdef
+ // to the op. Otherwise we skip node creation for this engine.
Logger trt_logger;
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
// TODO(sami): What happens if 1st dim is not batch?
TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
- info.segment_graph_def,
- info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode,
+ info.segment_graph_def, calibrate_int8 ? FP32MODE : info.precision_mode,
max_batch_size, info.max_workspace_size_bytes, input_shapes,
&trt_logger, alloc, /*calibrator=*/nullptr, &engine,
+ info.use_calibration,
/*convert_successfully=*/nullptr));
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
segment_string =
string((const char*)engine_data->data(), engine_data->size());
- if (info.precision_mode == INT8MODE) {
+ if (calibrate_int8) {
// See above comment about why not putting this inside the 'else' branch.
segment_string = info.segment_graph_def.SerializeAsString();
}
@@ -598,7 +616,7 @@
// conversion.
string prec_string;
TF_RETURN_IF_ERROR(GetPrecisionModeName(info.precision_mode, &prec_string));
- if (info.precision_mode == INT8MODE &&
+ if (info.precision_mode == INT8MODE && calibrate_int8 &&
!TRTResourceManager::instance()->getManager("TRTCalibration")) {
LOG(ERROR) << "Failed to construct calibration storage";
}
@@ -634,6 +652,7 @@
.Attr("cached_engine_batches", {max_batch_size})
.Attr("workspace_size_bytes", info.max_workspace_size_bytes)
.Attr("precision_mode", prec_string)
+ .Attr("use_calibration", info.use_calibration)
.Attr("OutT", out_types)
.Finalize(&trt_node);
if (!status.ok()) {
@@ -866,7 +885,8 @@
}
segment_options.minimum_segment_size = params.minimum_segment_size;
tensorflow::tensorrt::segment::SegmentNodesVector initial_segments;
- TrtCandidateSelector candidate_selector(*params.graph_properties);
+ TrtCandidateSelector candidate_selector(*params.graph_properties,
+ params.precision_mode);
TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
&graph,
std::bind(&TrtCandidateSelector::IsTensorRTCandidate, &candidate_selector,
@@ -904,10 +924,14 @@
continue;
}
curr_engine.precision_mode = params.precision_mode;
- curr_engine.engine_type =
- (params.is_dyn_op || params.precision_mode == INT8MODE
- ? EngineInfo::EngineType::TRTDynamic
- : EngineInfo::EngineType::TRTStatic);
+ if (params.use_calibration && params.precision_mode != INT8MODE) {
+ return errors::InvalidArgument(
+ "Calibration with FP32 or FP16 is not supported.");
+ }
+ curr_engine.engine_type = ((params.is_dyn_op || params.use_calibration)
+ ? EngineInfo::EngineType::TRTDynamic
+ : EngineInfo::EngineType::TRTStatic);
+ curr_engine.use_calibration = params.use_calibration;
curr_engine.cached_engine_batches = params.cached_engine_batches;
curr_engine.maximum_cached_engines = params.max_cached_engines;
StrAppend(&curr_engine.engine_name, "my_trt_op_", t);
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
index 1c9d821..1f39f56 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h
@@ -35,7 +35,8 @@
// supported by TRT.
class TrtCandidateSelector {
public:
- TrtCandidateSelector(const grappler::GraphProperties& graph_properties);
+ TrtCandidateSelector(const grappler::GraphProperties& graph_properties,
+ int precision_mode);
// Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added
// to TRT subgraph and later converted into TRT engine.
@@ -49,6 +50,9 @@
// GraphProperties of the graph whose nodes are to be validated by
// IsTensorRTCandidate().
const grappler::GraphProperties& graph_properties_;
+
+ // Quantization ops are only converted when using quantized precisions.
+ const int precision_mode_;
};
struct ConversionParams {
@@ -63,6 +67,7 @@
cluster(nullptr),
is_dyn_op(false),
fixed_input_size(true),
+ use_calibration(true),
max_cached_engines(1) {}
const tensorflow::GraphDef* input_graph_def;
const std::vector<string>* output_names;
@@ -76,6 +81,7 @@
bool is_dyn_op; // Whether to create engine on conversion or execution time
bool fixed_input_size; // Assume non-batch ranks of input tensors are fixed
int max_cached_engines; // maximum number of cached engines
+ bool use_calibration;
std::vector<int> cached_engine_batches; // list of cached engines
};
@@ -95,7 +101,7 @@
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
int precision_mode = 1, int minimum_segment_size = 3,
bool is_dyn_op = false, int max_cached_engines = 1,
- std::vector<int> cached_engine_batches = {});
+ std::vector<int> cached_engine_batches = {}, bool use_calibration = true);
// Method to call from optimization pass
tensorflow::Status ConvertAfterShapes(ConversionParams& params);
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc
index f107299..2d2bfeb 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc
@@ -85,27 +85,42 @@
ops::MatMul(s.WithOpName("matmul_with_incompatible_input"),
incompatible_feed, const_2);
+ // Quantize ops.
+ auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f);
+ auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("quantize"), feed,
+ quantize_attrs);
+
+ // Get GrapplerItem and GraphProperties.
grappler::GrapplerItem item;
TF_EXPECT_OK(s.ToGraphDef(&item.graph));
Tensor feed_tensor(DT_FLOAT, input_shape);
item.feed.push_back(std::make_pair("feed", feed_tensor));
-
grappler::GraphProperties graph_properties(item);
TF_EXPECT_OK(graph_properties.InferStatically(true));
- TrtCandidateSelector selector(graph_properties);
- TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node()));
- ExpectStatus(
- selector.IsTensorRTCandidate(incompatible_matmul.operation.node()),
- error::INVALID_ARGUMENT,
- "transpose_a is not supported for TensorRT FullyConnected "
- "(op: MatMul), at: incompatible_matmul");
- ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()),
- error::UNIMPLEMENTED, "Op type Sin is not supported");
- ExpectStatus(selector.IsTensorRTCandidate(
- matmul_with_incompatible_input.operation.node()),
- error::INTERNAL,
- "Failed to convert input with index 0 to a TRT_TensorOrWeights");
+ for (const int precision_mode : {FP32MODE, INT8MODE}) {
+ TrtCandidateSelector selector(graph_properties, precision_mode);
+ TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node()));
+ ExpectStatus(
+ selector.IsTensorRTCandidate(incompatible_matmul.operation.node()),
+ error::INVALID_ARGUMENT,
+ "transpose_a is not supported for TensorRT FullyConnected "
+ "(op: MatMul), at: incompatible_matmul");
+ ExpectStatus(selector.IsTensorRTCandidate(unsupported_op.operation.node()),
+ error::UNIMPLEMENTED, "Op type Sin is not supported");
+ ExpectStatus(
+ selector.IsTensorRTCandidate(
+ matmul_with_incompatible_input.operation.node()),
+ error::INTERNAL,
+ "Failed to convert input with index 0 to a TRT_TensorOrWeights");
+ if (precision_mode == INT8MODE) {
+ TF_EXPECT_OK(selector.IsTensorRTCandidate(quantize.operation.node()));
+ } else {
+ ExpectStatus(selector.IsTensorRTCandidate(quantize.operation.node()),
+ error::UNIMPLEMENTED,
+ "Op type FakeQuantWithMinMaxArgs is not supported");
+ }
+ }
}
class FakeCluster : public grappler::Cluster {
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index af9bbbf..cb2a1ca 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -202,6 +202,21 @@
}
}
+string DebugString(const nvinfer1::DataType trt_dtype) {
+ switch (trt_dtype) {
+ case nvinfer1::DataType::kFLOAT:
+ return "kFLOAT";
+ case nvinfer1::DataType::kHALF:
+ return "kHALF";
+ case nvinfer1::DataType::kINT8:
+ return "kINT8";
+ case nvinfer1::DataType::kINT32:
+ return "kINT32";
+ default:
+ return "Invalid TRT data type";
+ }
+}
+
string DebugString(const nvinfer1::Dims& dims) {
string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
for (int i = 0; i < dims.nbDims; ++i) {
@@ -222,16 +237,15 @@
string DebugString(const nvinfer1::ITensor& tensor) {
return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
- ", shape=", DebugString(tensor.getDimensions()), ")");
+ ", name=", tensor.getName(),
+ ", dtype=", DebugString(tensor.getType()),
+ ", dims=", DebugString(tensor.getDimensions()), ")");
}
-// Return whether or not the broadcast is feasible;
-bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l,
- const bool operand_l_is_tensor,
- const nvinfer1::Dims& operand_r,
- const bool operand_r_is_tensor,
- nvinfer1::Dims* operand_l_new_shape,
- nvinfer1::Dims* operand_r_new_shape) {
+Status Converter::GetTrtBroadcastShape(
+ const TRT_TensorOrWeights& operand_l, const TRT_TensorOrWeights& operand_r,
+ nvinfer1::Dims* operand_l_new_dims,
+ nvinfer1::Dims* operand_r_new_dims) const {
// ***************************************************************************
// TensorRT Elementwise op supports broadcast but requires both tensor to be
// of Identical rank
@@ -256,52 +270,59 @@
// -> T: 1 1 1 -1 3 5 1
// -> W: 1 1 1 1 3 5 1
// ***************************************************************************
+ if (!operand_l.is_tensor() && !operand_r.is_tensor()) {
+ return errors::InvalidArgument(
+ "Broadcasting requires at least one of the operands be tensors");
+ }
+
const int max_nb_dims = nvinfer1::Dims::MAX_DIMS + 1;
- const size_t element_size = sizeof(operand_l.d[0]);
+ auto compute_output_dims =
+ [max_nb_dims](const TRT_TensorOrWeights& input, int broadcast_num_dims,
+ int* output_dims_array, nvinfer1::Dims* output_dims) {
+ const nvinfer1::Dims input_dims = input.GetTrtDims();
+ std::fill(output_dims_array, output_dims_array + max_nb_dims, 1);
+ std::copy(input_dims.d, input_dims.d + input_dims.nbDims,
+ output_dims_array + broadcast_num_dims - input_dims.nbDims);
+ if (input.is_tensor()) {
+ const int true_input_dims = input_dims.nbDims + 1;
+ if (true_input_dims < broadcast_num_dims) {
+ return errors::InvalidArgument(
+ "Broadcasting beyond batch dimension is not supported ",
+ "(tensor #dims ", true_input_dims, " vs broadcast #dims ",
+ broadcast_num_dims, ")");
+ }
+ // Set the batch dimension to -1, since batch size is not supposed to
+ // be broadcasted.
+ output_dims_array[0] = -1;
+ }
+ // Copy to output dimensions (stripping the batch dimension).
+ output_dims->nbDims = broadcast_num_dims - 1;
+ std::copy(output_dims_array + 1, output_dims_array + broadcast_num_dims,
+ output_dims->d);
+ return Status::OK();
+ };
- // fill in dimensions
- int l_s[max_nb_dims];
- std::fill(l_s, l_s + max_nb_dims, 1);
- int l_d = operand_l_is_tensor ? operand_l.nbDims + 1 : operand_l.nbDims;
- int r_s[max_nb_dims];
- std::fill(r_s, r_s + max_nb_dims, 1);
- int r_d = operand_r_is_tensor ? operand_r.nbDims + 1 : operand_r.nbDims;
+ // Compute the output dimensions.
+ const int broadcast_num_dims =
+ std::max(operand_l.GetTrtDims().nbDims + (operand_l.is_tensor() ? 1 : 0),
+ operand_r.GetTrtDims().nbDims + (operand_r.is_tensor() ? 1 : 0));
+ int output_l[max_nb_dims], output_r[max_nb_dims];
+ TF_RETURN_IF_ERROR(compute_output_dims(operand_l, broadcast_num_dims,
+ output_l, operand_l_new_dims));
+ TF_RETURN_IF_ERROR(compute_output_dims(operand_r, broadcast_num_dims,
+ output_r, operand_r_new_dims));
- int max_d = std::max(l_d, r_d);
- std::memcpy(l_s + max_d - operand_l.nbDims, operand_l.d,
- operand_l.nbDims * element_size);
- std::memcpy(r_s + max_d - operand_r.nbDims, operand_r.d,
- operand_r.nbDims * element_size);
-
- // set -1 for batch dimension, since batch size is not supposed to be
- // broadcasted
- if (operand_l_is_tensor) {
- if (max_d != l_d) { // if broadcast beyond batch dimension, fail
- return false;
- }
- l_s[0] = -1;
- }
- if (operand_r_is_tensor) {
- if (max_d != r_d) { // if broadcast beyond batch dimension, fail
- return false;
- }
- r_s[0] = -1;
- }
-
- // compare broadcast feasibility
- for (int i = max_d - 1; i >= 0; i--) {
- if ((l_s[i] != r_s[i]) && (l_s[i] != 1) && (r_s[i] != 1)) {
- return false;
+ // Compare broadcast feasibility
+ for (int i = 0; i < broadcast_num_dims; ++i) {
+ if ((output_l[i] != output_r[i]) && (output_l[i] != 1) &&
+ (output_r[i] != 1)) {
+ return errors::InvalidArgument(
+ "Infeasible broadcast scheme (", "batch_dim: ", output_l[0], ", ",
+ DebugString(*operand_l_new_dims), " vs ", "batch_dim: ", output_r[0],
+ ", ", DebugString(*operand_r_new_dims), ")");
}
}
-
- // output new TensorRT Dimension (stripping the batch dimension)
- operand_l_new_shape->nbDims = max_d - 1;
- std::memcpy(operand_l_new_shape->d, l_s + 1, (max_d - 1) * element_size);
- operand_r_new_shape->nbDims = max_d - 1;
- std::memcpy(operand_r_new_shape->d, r_s + 1, (max_d - 1) * element_size);
-
- return true;
+ return Status::OK();
}
inline bool DimsEqual(const nvinfer1::Dims& dim_l,
@@ -449,7 +470,9 @@
void setLocation(nvinfer1::TensorLocation location) override {}
#if NV_TENSORRT_MAJOR >= 5
- bool setDynamicRange(float min, float max) override {}
+ bool setDynamicRange(float min, float max) override { return true; }
+
+ float getDynamicRange() const override { return 0; }
#endif
private:
@@ -513,8 +536,7 @@
string TRT_TensorOrWeights::DebugString() const {
string output = "TRT_TensorOrWeights(type=";
if (is_tensor()) {
- StrAppend(&output, "tensor @", reinterpret_cast<uintptr_t>(tensor()),
- ", shape=", convert::DebugString(tensor()->getDimensions()),
+ StrAppend(&output, "tensor=", convert::DebugString(*tensor()),
", batch_size=", batch_size_);
} else {
StrAppend(&output, "weights=", weights_.DebugString());
@@ -777,8 +799,9 @@
Status status = ConvertToTensorOrWeights(
*pair.first, pair.second, graph_properties, &tensor_or_weights);
if (!status.ok()) {
- return errors::Internal("Failed to convert input with index ", i,
- " to a TRT_TensorOrWeights");
+ return errors::Internal(
+ "Failed to convert input with index ", i,
+ " to a TRT_TensorOrWeights: ", status.error_message());
}
inputs.push_back(tensor_or_weights);
}
@@ -810,8 +833,11 @@
return status;
}
-Converter::Converter(nvinfer1::INetworkDefinition* trt_network, bool is_fp16)
- : trt_network_(trt_network), is_fp16_(is_fp16) {
+Converter::Converter(nvinfer1::INetworkDefinition* trt_network,
+ int precision_mode, bool use_calibration)
+ : trt_network_(trt_network),
+ precision_mode_(precision_mode),
+ use_calibration_(use_calibration) {
this->RegisterOpConverters();
}
@@ -836,13 +862,18 @@
TRT_TensorOrWeights& output = outputs[i];
string output_name = node_def.name();
if (i != 0) output_name = StrCat(output_name, ":", i);
- // We need to check the name before setting it. For Identity op where the
- // output is the input, if its input is one of the engine input, setting
- // the name here will overwrite engine input bindings which will cause
- // runtime error.
+ // We need to check the name before setting it. If the input is one of the
+ // engine input, setting the name here will overwrite engine input
+ // bindings which will cause runtime error.
if (output.is_tensor()) {
const char* tensor_name = output.tensor()->getName();
- if (tensor_name == nullptr || std::strlen(tensor_name) == 0) {
+ if (!tensorflow::str_util::StartsWith(tensor_name, kInputPHName)) {
+ // TRT initializes tensor names as "(Unnamed ITensor* N)". We rename
+ // them to match their corresponding TensorFlow name.
+ // Note: ITensors that we create internally within TF-TRT which are
+ // not inputs or outputs of a node will not be renamed. This is a
+ // potential cause of confusion if an error message or warning
+ // mentions the unnamed tensor.
output.tensor()->setName(output_name.c_str());
}
}
@@ -954,6 +985,7 @@
nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Transpose");
+ MarkQuantizationRangesAsInferrable(input_tensor, layer->getOutput(0));
nvinfer1::Permutation permutation;
for (int32_t i = 0; i < dims.nbDims; ++i) {
@@ -976,6 +1008,38 @@
return tensorflow::Status::OK();
}
+Status Converter::GetWeightRange(const TRT_ShapedWeights& weights,
+ float* out_min, float* out_max) const {
+ switch (weights.type_) {
+ case DataType::DT_FLOAT: {
+ auto inp = static_cast<float const*>(weights.GetValues());
+ auto result = std::minmax_element(inp, inp + weights.count());
+ *out_min = *result.first;
+ *out_max = *result.second;
+ break;
+ }
+ case DataType::DT_HALF: {
+ auto inp = static_cast<Eigen::half const*>(weights.GetValues());
+ auto result = std::minmax_element(inp, inp + weights.count());
+ *out_min = Eigen::half_impl::half_to_float(*result.first);
+ *out_max = Eigen::half_impl::half_to_float(*result.second);
+ break;
+ }
+ case DataType::DT_INT32: {
+ auto inp = static_cast<int const*>(weights.GetValues());
+ auto result = std::minmax_element(inp, inp + weights.count());
+ *out_min = static_cast<float>(*result.first);
+ *out_max = static_cast<float>(*result.second);
+ break;
+ }
+ default:
+ return errors::Unimplemented(
+ "Data type not supported for GetWeightRange: ",
+ DataTypeString(weights.type_));
+ }
+ return Status::OK();
+}
+
Status Converter::PrepareTensorForShape(const TRT_TensorOrWeights& input,
const nvinfer1::Dims& dims,
const nvinfer1::ITensor** tensor) {
@@ -990,8 +1054,9 @@
}
if (can_check_shapes &&
TrtDimsNumElements(input.GetTrtDims()) != TrtDimsNumElements(dims)) {
- return tensorflow::errors::InvalidArgument(
- "Reshape shapes are not compatible.");
+ return errors::InvalidArgument("Reshape shapes are not compatible (",
+ DebugString(input.GetTrtDims()), " vs ",
+ DebugString(dims), ")");
}
if (input.is_tensor()) {
@@ -1002,6 +1067,8 @@
*const_cast<nvinfer1::ITensor*>(input.tensor()));
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape");
layer->setReshapeDimensions(dims);
+ MarkQuantizationRangesAsInferrable(
+ const_cast<nvinfer1::ITensor*>(input.tensor()), layer->getOutput(0));
*tensor = layer->getOutput(0);
}
} else {
@@ -1009,10 +1076,123 @@
this->network()->addConstant(dims, input.weights().GetTrtWeights());
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, "TF-TRT Internal Reshape");
*tensor = layer->getOutput(0);
+ if (precision_mode() == INT8MODE && !use_calibration()) {
+ // If we are in int8 mode and not calibrating, we need to explicitly set a
+ // quantization range for the output tensor of the IConstantLayer. Here we
+ // set the range to [min(weights), max(weights)].
+ float min_range = 0.0f;
+ float max_range = 0.0f;
+ TF_RETURN_IF_ERROR(
+ GetWeightRange(input.weights(), &min_range, &max_range));
+ // Avoid setting range to 0 because TRT will throw an error. If the
+ // weights are zero then the range doesn't matter: using 127.0f should
+ // ensure the quantized weight will be exactly zero.
+ if (min_range == 0.0f && max_range == 0.0f) {
+ min_range = -127.0f;
+ max_range = 127.0f;
+ }
+ ProvideQuantizationRange(const_cast<nvinfer1::ITensor*>(*tensor),
+ min_range, max_range);
+ }
}
return tensorflow::Status::OK();
}
+void Converter::MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input,
+ nvinfer1::ITensor* output) {
+ quantization_infer_.push_back({input, output});
+ quantization_infer_.push_back({output, input});
+}
+
+void Converter::ProvideQuantizationRange(nvinfer1::ITensor* tensor,
+ float min_range, float max_range) {
+ float symmetric_range = std::max(std::abs(min_range), std::abs(max_range));
+ quantization_ranges_[tensor] = symmetric_range;
+}
+
+void Converter::MaybeApplyQuantizationRanges() {
+ if (precision_mode() != INT8MODE) return;
+
+ // Infer ranges across marked ops.
+ PropagateQuantizationRanges();
+ // Apply ranges.
+#if NV_TENSORRT_MAJOR >= 5
+ for (auto pair : quantization_ranges_) {
+ nvinfer1::ITensor* tensor = pair.first;
+ const float range = pair.second;
+ VLOG(1) << "Setting range for: " << tensor->getName() << ": " << range;
+ // TODO(laigd): if 'tensor' already has a range set which doesn't match
+ // 'range', it should report error.
+ tensor->setDynamicRange(-range, range);
+ }
+#endif
+
+ // Warn user about tensors that are missing ranges. If TRT fuses some layers
+ // then these tensors may not actually be required, which is why this is
+ // just a warning. If we are still missing ranges even after fusion,
+ // Builder::buildCudaEngine() will return nullptr and we will catch the
+ // error at that point.
+ if (!use_calibration()) {
+ // Get all tensors from network
+ std::set<nvinfer1::ITensor*> all_tensors;
+ for (int i = 0; i < this->network()->getNbLayers(); i++) {
+ nvinfer1::ILayer* layer = this->network()->getLayer(i);
+ for (int j = 0; j < layer->getNbInputs(); j++) {
+ all_tensors.insert(layer->getInput(j));
+ }
+ for (int j = 0; j < layer->getNbOutputs(); j++) {
+ all_tensors.insert(layer->getOutput(j));
+ }
+ }
+ // Find tensors with no ranges
+ for (auto tensor : all_tensors) {
+ if (!quantization_ranges_.count(tensor)) {
+ // Note: there may be some warnings for "(Unnamed ITensor* N)". These
+ // are tensors which are created internally by TF-TRT. The ranges for
+ // these unnamed ITensors are always inferred from user provided ranges,
+ // thus there will also be a warning for the range(s) the user missed.
+ LOG(WARNING) << "Quantization range was not found for "
+ << tensor->getName() << ". "
+ << "This is okay if TensorRT does not need the range "
+ << "(e.g. due to node fusion).";
+ }
+ }
+ }
+}
+
+void Converter::PropagateQuantizationRanges() {
+ // Propagate ranges across edges in quantization_infer_ until no new
+ // information is added.
+ // Note: this function modifies quantization_infer_, it might be better to
+ // modify a copy instead if we for some reason need quantization_infer_
+ // later.
+ bool information_added = true;
+ while (information_added) {
+ information_added = false;
+ for (auto it = quantization_infer_.begin();
+ it != quantization_infer_.end();) {
+ auto input_tensor_range = quantization_ranges_.find(it->first);
+ auto output_tensor_range = quantization_ranges_.find(it->second);
+ if (input_tensor_range != quantization_ranges_.end() &&
+ output_tensor_range == quantization_ranges_.end()) {
+ // Input has range but output doesn't: copy range
+ // TODO(laigd): consider reporting error if it a different range is
+ // already set.
+ quantization_ranges_[it->second] = input_tensor_range->second;
+ information_added = true;
+ VLOG(1) << "Copy quantization range: " << it->first->getName() << " -> "
+ << it->second->getName();
+ }
+ // We can remove edges when the output range is known
+ if (quantization_ranges_.find(it->second) != quantization_ranges_.end()) {
+ it = quantization_infer_.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ }
+}
+
Status Converter::GetInputs(const tensorflow::NodeDef& node_def,
std::vector<TRT_TensorOrWeights>* inputs) const {
for (auto const& input_name : node_def.input()) {
@@ -1069,12 +1249,11 @@
}
// ****************************************************************************
-// Constant folding functions
-// TODO(jie): once optimizer kicks in, we should have done constant folding
-// there.
+// Constant folding functions for weights.
+// TODO(laigd): we should probably use eigen directly.
// *****************************************************************************
struct LambdaFactory {
- enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB, RECIP };
+ enum class OP_CATEGORY : int { RSQRT = 0, NEG, RECIP };
OP_CATEGORY op;
template <typename T>
@@ -1089,7 +1268,7 @@
case OP_CATEGORY::RECIP:
return [](T t) -> T { return 1.0 / t; };
default:
- VLOG(2) << "Not supported op for unary: " << static_cast<int>(op);
+ LOG(ERROR) << "Not supported op for unary: " << static_cast<int>(op);
return nullptr;
}
}
@@ -1100,15 +1279,18 @@
switch (op) {
case OP_CATEGORY::RSQRT: {
VLOG(2) << "RSQRT GETS DONE";
- return [](Eigen::half t) -> Eigen::half {
+ return [](Eigen::half t) {
return Eigen::half(1.0 / sqrt(static_cast<float>(t)));
};
}
case OP_CATEGORY::NEG:
- return [](Eigen::half t) -> Eigen::half { return -t; };
- // TODO(aaroey): can we support RECIP?
+ return [](Eigen::half t) { return -t; };
+ case OP_CATEGORY::RECIP:
+ return [](Eigen::half t) {
+ return Eigen::half(1.0 / static_cast<float>(t));
+ };
default:
- VLOG(2) << "Not supported op for unary: " << static_cast<int>(op);
+ LOG(ERROR) << "Not supported op for unary: " << static_cast<int>(op);
return nullptr;
}
}
@@ -1140,50 +1322,48 @@
return tensorflow::Status::OK();
}
+// If swapped_inputs is false, 'tensor' is the left operand and 'weights' is the
+// right operand. If swapped_inputs is true, those two are swapped.
+//
// TODO(jie): broadcast is needed yet not implemented.
-// Only implemented channel wise for the time being
-tensorflow::Status BinaryTensorOpWeight(OpConverterParams* params,
- const nvinfer1::ITensor* tensor,
- TRT_ShapedWeights weights,
- bool swapped_inputs) {
+// Only implemented channel wise for the time being.
+Status BinaryTensorOpWeight(OpConverterParams* params,
+ const nvinfer1::ITensor* tensor,
+ TRT_ShapedWeights weights, bool swapped_inputs) {
+ static const std::unordered_set<string> supported_ops = {"Sub", "Add", "Mul",
+ "Div", "RealDiv"};
const auto& node_def = params->node_def;
- // tensor is the left operand while weights is the right operand;
- // when swapped_inputs set to true, those two are swapped.
- // TODO(aaroey): use a set.
- if (node_def.op() != "Sub" && node_def.op() != "Add" &&
- node_def.op() != "Mul" && node_def.op() != "Div" &&
- node_def.op() != "RealDiv") {
- return tensorflow::errors::Unimplemented(
- "op not supported: " + node_def.op() + ", at: " + node_def.name());
+ if (!supported_ops.count(node_def.op())) {
+ return errors::Unimplemented(node_def.op(), " is not supported, at ",
+ node_def.name());
}
- // Check type consistency
- nvinfer1::DataType ttype;
- TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &ttype));
+ // Check type consistency.
+ nvinfer1::DataType trt_dtype;
+ TF_RETURN_IF_ERROR(ConvertDType(weights.type_, &trt_dtype));
- // Check scale mode
+ // Check scale mode.
auto dims_w = weights.shape_;
- auto dims_t = tensor->getDimensions();
+ const auto dims_t = tensor->getDimensions();
// TODO(jie): addScale checks for input tensor dimension
if (dims_t.nbDims != 3) {
- return tensorflow::errors::InvalidArgument(
- "addScale requires tensor with rank 3, " + node_def.name());
+ return errors::InvalidArgument("addScale requires tensor with rank 3, at ",
+ node_def.name());
}
- // default to element-wise
+ // Default to element-wise
auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
// TODO(jie): maybe use a permutation instead to support more cases;
- bool permutation_flag = false;
+ bool need_to_permute = false;
if (weights.count() == 1) {
- VLOG(2) << "UNIFORM";
scale_mode = nvinfer1::ScaleMode::kUNIFORM;
} else {
- // no broadcasting on Batch dimension;
- VLOG(2) << "WEIGHTS DIM: " << dims_w.nbDims
- << " tensor DIM: " << dims_t.nbDims;
+ VLOG(2) << "weights dims: " << DebugString(dims_w)
+ << "; tensor dims: " << DebugString(dims_t);
+ // Make sure no broadcasting on batch dimension.
if (dims_w.nbDims == dims_t.nbDims + 1) {
if (dims_w.d[0] == 1) {
for (int i = 1; i < dims_w.nbDims; i++) {
@@ -1191,72 +1371,70 @@
}
dims_w.nbDims--;
} else {
- return tensorflow::errors::InvalidArgument(
- "Binary op cannot operate on batch, " + node_def.name());
+ return errors::InvalidArgument("Binary op cannot operate on batch, at ",
+ node_def.name());
}
}
if (dims_w.nbDims == dims_t.nbDims && dims_w.d[0] == dims_t.d[0]) {
scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
- // default is element;
+ // Default is element-wise
for (int i = 1; i < dims_w.nbDims; i++) {
if (dims_w.d[i] != dims_t.d[i]) {
- // if dimension does not match, switch back to channel;
- VLOG(2) << "channel";
+ // If dimension does not match, switch back to per-channel
scale_mode = nvinfer1::ScaleMode::kCHANNEL;
break;
}
}
- // if channel as candidate, validate it
+ // If the mode is per-channel, since channel dimension is assumed to be
+ // the third to last dimension, we need to make sure all other dimensions
+ // have size 1.
if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
for (int i = 1; i < dims_w.nbDims; i++) {
if (dims_w.d[i] != 1)
- return tensorflow::errors::InvalidArgument(
- "Weight shape not compatible at, " + node_def.name());
+ return errors::InvalidArgument(
+ "Weight dims not compatible for channel-wise broadcast at ",
+ node_def.name());
}
- } else {
- VLOG(2) << "elementwise";
}
} else if (dims_w.nbDims == 1 &&
dims_w.d[0] == dims_t.d[dims_t.nbDims - 1]) {
- // channel wise and broadcast required;
- permutation_flag = true;
+ // Channel wise and broadcast required. We compare the last dimension of
+ // the tensor shape because of tensorflow default broadcasting rules.
+ need_to_permute = true;
scale_mode = nvinfer1::ScaleMode::kCHANNEL;
} else {
- return tensorflow::errors::InvalidArgument(
- "Weight shape not compatible at, " + node_def.name());
+ return errors::InvalidArgument("Weight dims not compatible at ",
+ node_def.name());
}
}
+ // TODO(laigd): we should add validation_only support in TransposeTensor() and
+ // PrepareTensorForShape().
+ if (params->validation_only) return Status::OK();
- // transpose last dimension
+ // Transpose last dimension.
std::vector<int> permutation(dims_t.nbDims + 1);
- if (permutation_flag) {
- if (scale_mode == nvinfer1::ScaleMode::kCHANNEL && dims_t.nbDims > 1) {
- // we swap the last dimension into channel for trt.
- // because of tensorflow default broadcasting rules.
- for (int i = 0; i < static_cast<int>(permutation.size()); i++) {
- permutation[i] = i;
- }
- permutation[1] = dims_t.nbDims;
- permutation[dims_t.nbDims] = 1;
- TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
- const_cast<nvinfer1::ITensor*>(tensor), permutation, &tensor));
- } else {
- return tensorflow::errors::InvalidArgument(
- "Transpose cannot be applied, " + node_def.name());
+ if (need_to_permute) {
+ // We swap the last dimension into channel for trt, because of tensorflow
+ // default broadcasting rules.
+ for (int i = 0; i < static_cast<int>(permutation.size()); i++) {
+ permutation[i] = i;
}
+ permutation[1] = dims_t.nbDims;
+ permutation[dims_t.nbDims] = 1;
+ TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
+ const_cast<nvinfer1::ITensor*>(tensor), permutation, &tensor));
}
- if (params->converter->is_fp16()) {
+ if (params->converter->precision_mode() == FP16MODE) {
weights = ConvertFP32ToFP16(params->weight_store, weights);
}
- // prepare weights
+ // Prepare weights
TRT_ShapedWeights shift_weights(weights.type_);
TRT_ShapedWeights scale_weights(weights.type_);
TRT_ShapedWeights power_weights(weights.type_);
- // Maybe I should do a switch
if (node_def.op() == "Sub") {
if (swapped_inputs) {
shift_weights = weights;
@@ -1264,6 +1442,10 @@
*const_cast<nvinfer1::ITensor*>(tensor),
nvinfer1::UnaryOperation::kNEG);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ // Since quantization ranges are symmetric, the same range as the input
+ // will work for the negation of the input.
+ params->converter->MarkQuantizationRangesAsInferrable(
+ const_cast<nvinfer1::ITensor*>(tensor), layer->getOutput(0));
tensor = layer->getOutput(0);
} else {
TRT_ShapedWeights neg_weights =
@@ -1275,6 +1457,25 @@
}
} else if (node_def.op() == "Div" || node_def.op() == "RealDiv") {
if (swapped_inputs) {
+ // We need to infer the quantization range for this intermediate tensor.
+ //
+ // x -> [Recip] -> 1/x -> [Scale] -> s/x
+ // ^
+ // need range for this
+ //
+ // We have the quantization scales for x and s/x - can we divide the scale
+ // for s/x by s? Only if it is a scalar.
+ //
+ // Because of this issue, fall back to BinaryTensorOpTensor if we are
+ // doing INT8 with no calibration. There is most likely no performance
+ // penalty by falling back here.
+ if (params->converter->precision_mode() == INT8MODE &&
+ !params->converter->use_calibration()) {
+ return errors::Unimplemented(
+ "Intermediate quantization range cannot be determined without"
+ " calibration. Falling back to BinaryTensorOpTensor for ",
+ node_def.op(), ", at ", node_def.name());
+ }
scale_weights = weights;
nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(
*const_cast<nvinfer1::ITensor*>(tensor),
@@ -1294,8 +1495,8 @@
} else if (node_def.op() == "Add") {
shift_weights = weights;
} else {
- return tensorflow::errors::Unimplemented("Binary op not supported: " +
- node_def.op());
+ // This should not happen.
+ return errors::Unimplemented("Binary op not supported at ", node_def.op());
}
nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
@@ -1305,8 +1506,8 @@
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
const nvinfer1::ITensor* output_tensor = layer->getOutput(0);
- // transpose back dimension
- if (permutation_flag) {
+ // Transpose back dimension
+ if (need_to_permute) {
TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
const_cast<nvinfer1::ITensor*>(output_tensor), permutation,
&output_tensor));
@@ -1350,7 +1551,7 @@
return tensorflow::errors::Internal(
"Conv2D expects kernel of dimension 4, at: " + node_def.name());
}
- if (params->converter->is_fp16()) {
+ if (params->converter->precision_mode() == FP16MODE) {
weights_rsck =
ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights());
}
@@ -1397,6 +1598,8 @@
nvinfer1::DimsHW(padding[0].first, padding[1].first),
nvinfer1::DimsHW(padding[0].second, padding[1].second));
TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name());
+ params->converter->MarkQuantizationRangesAsInferrable(
+ const_cast<nvinfer1::ITensor*>(tensor), pad_layer->getOutput(0));
padding = {{0, 0}, {0, 0}};
tensor = pad_layer->getOutput(0);
VLOG(2) << "TENSOR after: " << DebugString(tensor->getDimensions());
@@ -1438,9 +1641,9 @@
params->node_def.name());
}
-tensorflow::Status BinaryTensorOpTensor(OpConverterParams* params,
- const TRT_TensorOrWeights& operand_l,
- const TRT_TensorOrWeights& operand_r) {
+Status BinaryTensorOpTensor(OpConverterParams* params,
+ const TRT_TensorOrWeights& operand_l,
+ const TRT_TensorOrWeights& operand_r) {
const auto& node_def = params->node_def;
static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
{"Add", nvinfer1::ElementWiseOperation::kSUM},
@@ -1451,50 +1654,52 @@
{"Minimum", nvinfer1::ElementWiseOperation::kMIN},
{"Maximum", nvinfer1::ElementWiseOperation::kMAX},
};
-
- const nvinfer1::ITensor* tensor_l;
- const nvinfer1::ITensor* tensor_r;
-
- nvinfer1::Dims dim_l;
- nvinfer1::Dims dim_r;
-
- if (!TensorRTGetBroadcastShape(operand_l.GetTrtDims(), operand_l.is_tensor(),
- operand_r.GetTrtDims(), operand_r.is_tensor(),
- &dim_l, &dim_r)) {
- return tensorflow::errors::InvalidArgument(
- "Binary op broadcast scheme not supported by TensorRT op: " +
- node_def.op() + ", at: " + node_def.name());
- }
-
- TF_RETURN_IF_ERROR(
- params->converter->PrepareTensorForShape(operand_l, dim_l, &tensor_l));
- TF_RETURN_IF_ERROR(
- params->converter->PrepareTensorForShape(operand_r, dim_r, &tensor_r));
-
- // get trt type & shape
- TFAttrs attrs(node_def);
- // maybe this part has to be moved into the block of rsqrt later
- nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
-
- // check type consistency
- TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype);
- TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype);
auto op_pair = ops.find(node_def.op());
if (op_pair == ops.end()) {
- return tensorflow::errors::Unimplemented(
- "binary op: ", node_def.op(), " not supported at: ", node_def.name());
+ return errors::Unimplemented("Binary op ", node_def.op(),
+ " not supported at: ", node_def.name());
}
+ nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
+ Status status = params->converter->GetTrtBroadcastShape(
+ operand_l, operand_r, &broadcasted_dims_l, &broadcasted_dims_r);
+ if (!status.ok()) {
+ return errors::InvalidArgument(
+ "Unsupported binary op broadcast scheme for op ", node_def.name(), ": ",
+ status.error_message());
+ }
+ if (params->validation_only) return Status::OK();
+
+ const nvinfer1::ITensor* tensor_l = nullptr;
+ const nvinfer1::ITensor* tensor_r = nullptr;
+ status = params->converter->PrepareTensorForShape(
+ operand_l, broadcasted_dims_l, &tensor_l);
+ if (status.ok()) {
+ status = params->converter->PrepareTensorForShape(
+ operand_r, broadcasted_dims_r, &tensor_r);
+ }
+ if (!status.ok()) {
+ return errors::Internal("Failed to convert binary op ", node_def.name(),
+ ": ", status.error_message());
+ }
+
+ // Check type consistency.
+ TFAttrs attrs(node_def);
+ nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
+ TFTRT_CHECK_EQ_TYPE(tensor_l->getType(), dtype)
+ << DebugString(tensor_l->getType()) << " vs " << DebugString(dtype);
+ TFTRT_CHECK_EQ_TYPE(tensor_r->getType(), dtype)
+ << DebugString(tensor_r->getType()) << " vs " << DebugString(dtype);
+
+ // Add ElementWise layer.
nvinfer1::IElementWiseLayer* layer =
params->converter->network()->addElementWise(
- // TODO(aaroey): will tensor_l/tensor_r get modified?
*const_cast<nvinfer1::ITensor*>(tensor_l),
*const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
-
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
- // pass the output
+ // Pass the output
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
}
@@ -1741,6 +1946,8 @@
nvinfer1::DimsHW(padding[0].first, padding[1].first),
nvinfer1::DimsHW(padding[0].second, padding[1].second));
TFTRT_RETURN_ERROR_IF_NULLPTR(pad_layer, node_def.name());
+ params->converter->MarkQuantizationRangesAsInferrable(
+ const_cast<nvinfer1::ITensor*>(tensor), pad_layer->getOutput(0));
padding = {{0, 0}, {0, 0}};
tensor = pad_layer->getOutput(0);
}
@@ -1748,6 +1955,11 @@
nvinfer1::IPoolingLayer* layer = params->converter->network()->addPooling(
*const_cast<nvinfer1::ITensor*>(tensor), type, ksize);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ // TODO(tmorris): Average pooling may not be entirely safe to infer
+ // quantization range through (at least forwards - backwards should be fine).
+ // Max pooling is okay.
+ params->converter->MarkQuantizationRangesAsInferrable(
+ const_cast<nvinfer1::ITensor*>(tensor), layer->getOutput(0));
layer->setStride(stride);
layer->setPadding({padding[0].first, padding[1].first});
@@ -1776,6 +1988,148 @@
return tensorflow::Status::OK();
}
+Status ConvertQuantize(OpConverterParams* params) {
+ const auto& inputs = params->inputs;
+ const auto& node_def = params->node_def;
+ if ((inputs.size() == 0) ||
+ (node_def.op() == "FakeQuantWithMinMaxArgs" && inputs.size() != 1) ||
+ (node_def.op() == "FakeQuantWithMinMaxVars" && inputs.size() != 3) ||
+ (node_def.op() == "QuantizeAndDequantizeV2" && inputs.size() != 3) ||
+ (node_def.op() == "QuantizeAndDequantizeV3" && inputs.size() != 4)) {
+ return errors::InvalidArgument("Invalid number of inputs for ",
+ node_def.op(), ", at ", node_def.name());
+ }
+ if (inputs.at(0).is_weights()) {
+ // TensorRT will automatically quantize weights, so we will ignore ranges
+ // for weights.
+ params->outputs->push_back(inputs.at(0));
+ return Status::OK();
+ }
+ float min_range = 0.0f;
+ float max_range = 0.0f;
+ if (node_def.op() == "FakeQuantWithMinMaxArgs") {
+ // Get ranges via node attributes.
+ TFAttrs attrs(node_def);
+ if (attrs.count("min") == 0 || attrs.count("max") == 0) {
+ return errors::InvalidArgument("Min or max attribute not found for ",
+ node_def.op(), " at ", node_def.name());
+ }
+ min_range = attrs.get<float>("min");
+ max_range = attrs.get<float>("max");
+ } else if (node_def.op() == "FakeQuantWithMinMaxVars" ||
+ node_def.op() == "QuantizeAndDequantizeV2" ||
+ node_def.op() == "QuantizeAndDequantizeV3") {
+ // Get ranges via inputs.
+ if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights()) {
+ return errors::InvalidArgument("Min and max inputs for ", node_def.op(),
+ " must be weights not tensors, at ",
+ node_def.name());
+ }
+ auto get_weights_value = [&inputs](int index) {
+ auto raw_weights = static_cast<float*>(
+ const_cast<void*>(inputs.at(index).weights().GetValues()));
+ return raw_weights[0];
+ };
+ min_range = get_weights_value(1);
+ max_range = get_weights_value(2);
+ } else {
+ return errors::InvalidArgument("Unknown quantization op ", node_def.op(),
+ ", at ", node_def.name());
+ }
+ if (params->validation_only) return Status::OK();
+
+ // Store ranges for tensor
+ params->converter->ProvideQuantizationRange(
+ const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor()), min_range,
+ max_range);
+ // Sometimes, TRT may not quantize a tensor, either because it chooses to
+ // execute a higher precision kernel or because of op fusion. In these cases,
+ // accuracy will suffer if the model was trained to expect quantization at
+ // that tensor. We should consider adding a clip(tensor, min_range, max_range)
+ // operation here to ensure that any arbitrarily placed quantize node will
+ // execute as expected. However, this will negatively affect performance. If
+ // users train their models in a way which models inference as close as
+ // possible (i.e. not quantizing in place where fusion will occur), then there
+ // is no problem with the current implementation.
+ params->outputs->push_back(inputs.at(0));
+ return Status::OK();
+}
+
+// TODO(pdavoodi): we should update relu6 implementation once TensorRT supports
+// Relu6 natively.
+tensorflow::Status ConvertRelu6(OpConverterParams* params) {
+ const auto& inputs = params->inputs;
+ const auto& node_def = params->node_def;
+ if (inputs.size() != 1) {
+ return tensorflow::errors::InvalidArgument(
+ "Invalid number of inputs for Relu6, at ", node_def.name());
+ }
+ if (inputs.at(0).is_weights()) {
+ return tensorflow::errors::Unimplemented(
+ "Relu6 is only implemented for tensors, not weights, at ",
+ node_def.name());
+ }
+ if (params->validation_only) return Status::OK();
+ // ***************************************************************************
+ // TensorRT does not implement Relu6 natively. This function converts Relu6 op
+ // to available TensorRT ops: Relu6(x) = min(Relu(x), 6)
+ // ***************************************************************************
+
+ // Input Tensor
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+
+ // Relu operation i.e. Relu(x) = max(0, x)
+ nvinfer1::IActivationLayer* relu_layer =
+ params->converter->network()->addActivation(
+ *const_cast<nvinfer1::ITensor*>(tensor),
+ nvinfer1::ActivationType::kRELU);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(relu_layer, node_def.name());
+
+ // Large range of relu is problematic during quantization in INT8 precision
+ // mode. Setting dynamic range of relu = [0.f, 6.0f] helps with quantization.
+ // TRT only uses dynamic ranges in INT8 precision mode,
+ // and this does not affect the FP32 path.
+ params->converter->ProvideQuantizationRange(relu_layer->getOutput(0), 0.0f,
+ 6.0f);
+
+ // Create a constant layer to store the floating point weight i.e. 6.0f This
+ // tensor will be broadcasted uniformly during elementwise `min` operation.
+ // The constant has to have the same rank as the input in order for TRT to
+ // broadcast
+ nvinfer1::Dims dims;
+ dims.nbDims = relu_layer->getOutput(0)->getDimensions().nbDims;
+ for (int i = 0; i < dims.nbDims; i++) {
+ dims.d[i] = 1;
+ }
+ TRT_ShapedWeights weights = params->weight_store->GetTempWeights(
+ tensorflow::DataType::DT_FLOAT, dims);
+ auto weights_ptr =
+ static_cast<float*>(const_cast<void*>(weights.GetValues()));
+ weights_ptr[0] = 6.0f;
+ nvinfer1::IConstantLayer* const6_layer =
+ params->converter->network()->addConstant(dims, weights.GetTrtWeights());
+ TFTRT_RETURN_ERROR_IF_NULLPTR(const6_layer, node_def.name());
+ params->converter->ProvideQuantizationRange(const6_layer->getOutput(0), 0.0f,
+ 6.0f);
+
+ // ElementWise Min Operation
+ // Min op is a nop for INT8 execution path, as the input tensor
+ // to this layer will only have values in range [0.f, 6.0f].
+ const nvinfer1::ITensor* tensor_l = relu_layer->getOutput(0);
+ const nvinfer1::ITensor* tensor_r = const6_layer->getOutput(0);
+ nvinfer1::IElementWiseLayer* relu6_layer =
+ params->converter->network()->addElementWise(
+ *const_cast<nvinfer1::ITensor*>(tensor_l),
+ *const_cast<nvinfer1::ITensor*>(tensor_r),
+ nvinfer1::ElementWiseOperation::kMIN);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(relu6_layer, node_def.name());
+ nvinfer1::ITensor* output_tensor = relu6_layer->getOutput(0);
+ params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f);
+
+ params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return Status::OK();
+}
+
tensorflow::Status ConvertBiasAdd(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
@@ -1786,7 +2140,8 @@
}
if (params->validation_only) return Status::OK();
- const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+ nvinfer1::ITensor* tensor =
+ const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor());
const nvinfer1::Dims original_dims = tensor->getDimensions();
TFAttrs attrs(node_def);
const string data_format = attrs.get<string>("data_format");
@@ -1802,18 +2157,20 @@
}
permutation.order[0] = channel_index;
permutation.order[channel_index] = 0;
+ VLOG(1) << "ConvertBiasAdd permutation: "
+ << DebugString(permutation, original_dims.nbDims);
}
- VLOG(1) << "ConvertBiasAdd permutation: "
- << DebugString(permutation, original_dims.nbDims);
// TensorRT addScale requires input to be of rank 3, we need to apply
// transpose as well as reshape.
// TODO(laigd): this doesn't match what the TRT doc says, fix the doc?
if (channel_index != 0 || original_dims.nbDims != 3) {
nvinfer1::IShuffleLayer* shuffle_layer =
- params->converter->network()->addShuffle(
- *const_cast<nvinfer1::ITensor*>(tensor));
+ params->converter->network()->addShuffle(*tensor);
TFTRT_RETURN_ERROR_IF_NULLPTR(shuffle_layer, node_def.name());
+ params->converter->MarkQuantizationRangesAsInferrable(
+ tensor, shuffle_layer->getOutput(0));
+
// NOTE(laigd): for some reason we need to apply the reshape
// unconditionally. The default shape has nbDims==-1 and it seems the
// behavior is undefined in some cases.
@@ -1832,7 +2189,7 @@
}
TRT_ShapedWeights weights = inputs.at(1).weights();
- if (params->converter->is_fp16()) {
+ if (params->converter->precision_mode() == FP16MODE) {
weights = ConvertFP32ToFP16(params->weight_store, weights);
}
nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL;
@@ -1842,8 +2199,8 @@
TRT_ShapedWeights empty_weights(weights.type_);
nvinfer1::IScaleLayer* layer = params->converter->network()->addScale(
- *const_cast<nvinfer1::ITensor*>(tensor), mode, weights.GetTrtWeights(),
- empty_weights.GetTrtWeights(), empty_weights.GetTrtWeights());
+ *tensor, mode, weights.GetTrtWeights(), empty_weights.GetTrtWeights(),
+ empty_weights.GetTrtWeights());
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
@@ -1867,6 +2224,8 @@
if (channel_index != 0) {
shuffle_layer->setSecondTranspose(permutation);
}
+ params->converter->MarkQuantizationRangesAsInferrable(
+ output_tensor, shuffle_layer->getOutput(0));
output_tensor = shuffle_layer->getOutput(0);
}
@@ -2025,32 +2384,41 @@
}
tensorflow::Status ConvertIdentity(OpConverterParams* params) {
+ // TODO(tmorris): TRT's Identity layer does not get optimized away as of TRT
+ // 5.0, however once we know that it does it would be nice to use that
+ // instead.
params->outputs->push_back(params->inputs.at(0));
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertBinary(OpConverterParams* params) {
+Status ConvertBinary(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
if (inputs.size() != 2) {
- return tensorflow::errors::FailedPrecondition(
- "Binary ops require two tensor input, at ", node_def.name());
+ return errors::InvalidArgument("Binary ops require two inputs, at ",
+ node_def.name());
}
// Constant folding should have been done by TensorFlow
-
if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
- return tensorflow::errors::Unimplemented(
+ return errors::Unimplemented(
"Constant folding is falled back to TensorFlow, binary op received "
"both input as constant at: ",
node_def.name());
}
- // Try to convert into Scale layer first (for better performance)
+ // TODO(tmorris): TRT plans to deprecate IScaleLayer and will replace it with
+ // IElementwiseLayer. At that point, we can remove BinaryTensorOpWeight. For
+ // now, the performance will be slightly better with IScaleLayer because it
+ // can be fused in more situations. However, most of the benefits of
+ // IScaleLayer are when the layer performs both a shift and a scale, which we
+ // don't do except for convolutions.
+ //
+ // Try to convert into Scale layer first (for better performance).
// Since scale layer supports restricted broadcast policy and op types, we
// allow failure and try to handle it through Elementwise op
- // (BinaryTensorOpTensor)
- Status status = tensorflow::Status::OK();
+ // (BinaryTensorOpTensor).
+ Status status = Status::OK();
if (inputs.at(0).is_tensor() && inputs.at(1).is_weights()) {
status = BinaryTensorOpWeight(params, inputs.at(0).tensor(),
inputs.at(1).weights(), false);
@@ -2058,7 +2426,10 @@
status = BinaryTensorOpWeight(params, inputs.at(1).tensor(),
inputs.at(0).weights(), true);
}
+ // If both input are tensors, or one of them is weights but the conversion
+ // above failed, try the conversion using BinaryTensorOpTensor.
if ((inputs.at(0).is_tensor() && inputs.at(1).is_tensor()) || !status.ok()) {
+ if (!status.ok()) VLOG(1) << status;
status = BinaryTensorOpTensor(params, inputs.at(0), inputs.at(1));
}
return status;
@@ -2088,6 +2459,20 @@
nvinfer1::IUnaryLayer* layer;
if (node_def.op() == "Rsqrt") {
+ // We will need a quantization range for intermediate tensor if not using
+ // calibration.
+ //
+ // x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x)
+ // ^
+ // need range here
+ if (params->converter->precision_mode() == INT8MODE &&
+ !params->converter->use_calibration()) {
+ return errors::Unimplemented(
+ "Intermediate quantization range cannot be determined without"
+ " calibration for Rsqrt, consider replacing with "
+ "Sqrt -> FakeQuant -> Reciprocal ops, at ",
+ node_def.name());
+ }
layer = params->converter->network()->addUnary(
*const_cast<nvinfer1::ITensor*>(tensor),
nvinfer1::UnaryOperation::kSQRT);
@@ -2647,6 +3032,8 @@
layer->setAxes(1 << (nbDims - 1));
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ // Quantization range for SoftMax is always (0, 1)
+ params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 1.0f);
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return tensorflow::Status::OK();
}
@@ -2687,41 +3074,49 @@
return tensorflow::Status::OK();
}
-void TrtNodeValidator::RegisterOpValidators() {
+static void RegisterValidatableOpConverters(
+ std::unordered_map<string, OpConverter>* registration) {
// TODO(laigd): support all op types.
- op_validators_["BiasAdd"] = ConvertBiasAdd;
- op_validators_["Const"] = ConvertConst;
- op_validators_["Transpose"] = ConvertTranspose;
- op_validators_["Reshape"] = ConvertReshape;
- op_validators_["MatMul"] = ConvertMatMul;
+ (*registration)["BiasAdd"] = ConvertBiasAdd;
+ (*registration)["Const"] = ConvertConst;
+ (*registration)["Transpose"] = ConvertTranspose;
+ (*registration)["Reshape"] = ConvertReshape;
+ (*registration)["MatMul"] = ConvertMatMul;
+ (*registration)["Relu6"] = ConvertRelu6;
+
+ for (auto quantization_op_type :
+ {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3",
+ "FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxArgs"}) {
+ (*registration)[quantization_op_type] = ConvertQuantize;
+ }
+ for (auto binary_op_type :
+ {"Add", "Mul", "Sub", "Div", "RealDiv", "Maximum", "Minimum"}) {
+ (*registration)[binary_op_type] = ConvertBinary;
+ }
+}
+
+void TrtNodeValidator::RegisterOpValidators() {
+ RegisterValidatableOpConverters(&op_validators_);
}
void Converter::RegisterOpConverters() {
- // vgg_16 slim implementation
+ RegisterValidatableOpConverters(&op_registry_);
+
op_registry_["Conv2D"] = ConvertConv2D;
op_registry_["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
op_registry_["Relu"] = ConvertActivation;
op_registry_["MaxPool"] = ConvertPool;
op_registry_["AvgPool"] = ConvertPool;
- op_registry_["BiasAdd"] = ConvertBiasAdd;
- op_registry_["Const"] = ConvertConst;
// TODO(ben,jie): this is a temp hack.
op_registry_["Identity"] = ConvertIdentity; // Identity should be removed
op_registry_["Snapshot"] = ConvertIdentity; // Snapshot should be removed
- // resnet_50_v1 slim implementation
- op_registry_["Add"] = ConvertBinary;
- op_registry_["Mul"] = ConvertBinary;
- op_registry_["Sub"] = ConvertBinary;
op_registry_["Pad"] = ConvertPad;
op_registry_["ConcatV2"] = ConvertConcat;
op_registry_["FusedBatchNorm"] = ConvertFusedBatchNorm;
op_registry_["FusedBatchNormV2"] = ConvertFusedBatchNorm;
- op_registry_["Div"] = ConvertBinary;
- op_registry_["RealDiv"] = ConvertBinary;
-
op_registry_["Rsqrt"] = ConvertUnary;
op_registry_["Reciprocal"] = ConvertUnary;
op_registry_["Exp"] = ConvertUnary;
@@ -2730,18 +3125,12 @@
op_registry_["Abs"] = ConvertUnary;
op_registry_["Neg"] = ConvertUnary;
- op_registry_["Transpose"] = ConvertTranspose;
- op_registry_["Reshape"] = ConvertReshape;
-
op_registry_["Sum"] = ConvertReduce;
op_registry_["Prod"] = ConvertReduce;
op_registry_["Max"] = ConvertReduce;
op_registry_["Min"] = ConvertReduce;
op_registry_["Mean"] = ConvertReduce;
- op_registry_["Maximum"] = ConvertBinary;
- op_registry_["Minimum"] = ConvertBinary;
op_registry_["Softmax"] = ConvertSoftmax;
- op_registry_["MatMul"] = ConvertMatMul;
op_registry_["BatchMatMul"] = ConvertBatchMatMul;
op_registry_["TopKV2"] = ConvertTopK;
@@ -2754,7 +3143,7 @@
const std::vector<tensorflow::PartialTensorShape>& input_shapes,
Logger* logger, nvinfer1::IGpuAllocator* allocator,
TRTInt8Calibrator* calibrator,
- TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
+ TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
bool* convert_successfully) {
engine->reset();
if (convert_successfully) *convert_successfully = false;
@@ -2769,7 +3158,11 @@
builder->setHalf2Mode(true);
} else if (precision_mode == INT8MODE) {
builder->setInt8Mode(true);
- builder->setInt8Calibrator(calibrator);
+ if (use_calibration) {
+ builder->setInt8Calibrator(calibrator);
+ } else {
+ builder->setInt8Calibrator(nullptr);
+ }
}
// Create the network.
@@ -2782,7 +3175,7 @@
// Build the network
VLOG(1) << "Starting engine conversion ";
- Converter converter(trt_network.get(), precision_mode == FP16MODE);
+ Converter converter(trt_network.get(), precision_mode, use_calibration);
std::vector<std::pair<string, string>> output_tensors;
// Graph nodes are already topologically sorted during construction
for (const auto& node_def : gdef.node()) {
@@ -2838,6 +3231,9 @@
TF_RETURN_IF_ERROR(converter.RenameAndMarkOutputTensors(output_tensors));
if (convert_successfully) *convert_successfully = true;
+ // Apply user provided quantization ranges to tensors
+ converter.MaybeApplyQuantizationRanges();
+
// Build the engine.
VLOG(1) << "Starting engine creation";
engine->reset(builder->buildCudaEngine(*converter.network()));
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index 5cc28b3..54e19b7 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -92,7 +92,8 @@
EngineInfo()
: engine_type(EngineType::TRTStatic),
max_workspace_size_bytes(0),
- precision_mode(FP32MODE) {}
+ precision_mode(FP32MODE),
+ use_calibration(true) {}
string engine_name;
string device;
@@ -109,6 +110,7 @@
int maximum_cached_engines;
std::vector<int> cached_engine_batches;
int precision_mode;
+ bool use_calibration;
};
// Constructs a graphdef from the segment in the given graph. Adds placeholder
@@ -145,7 +147,7 @@
const std::vector<tensorflow::PartialTensorShape>& input_shapes,
Logger* logger, nvinfer1::IGpuAllocator* allocator,
TRTInt8Calibrator* calibrator,
- TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
+ TrtUniquePtrType<nvinfer1::ICudaEngine>* engine, bool use_calibration,
bool* convert_successfully);
// Helper class for the segmenter to determine whether an output edge from the
@@ -392,7 +394,8 @@
// Class to convert TF nodes to TRT network.
class Converter {
public:
- Converter(nvinfer1::INetworkDefinition* trt_network, bool is_fp16);
+ Converter(nvinfer1::INetworkDefinition* trt_network, int precision_mode,
+ bool use_calibration);
//////////////////////////////////////////////////////////////////////////////
// Methods used by the TRT engine builder to build a TRT network from a TF
@@ -422,8 +425,27 @@
// to add TRT layers.
nvinfer1::INetworkDefinition* network() { return trt_network_; }
- // Is the converter operating in fp16 mode?
- bool is_fp16() const { return is_fp16_; }
+ // What precision are we targeting?
+ int precision_mode() const { return precision_mode_; }
+
+ // Calibration will be or was previously performed on this network?
+ bool use_calibration() const { return use_calibration_; }
+
+ // This should be called on the inputs and outputs of any layer we create
+ // where we know that the quantization range does not change during that
+ // operation. (e.g. Reshape, Transpose, Identity, MaxPool).
+ void MarkQuantizationRangesAsInferrable(nvinfer1::ITensor* input,
+ nvinfer1::ITensor* output);
+
+ // This function should be called when we know the quantization range of a
+ // tensor, either from a quantize/dequantize node or when the output is a
+ // fixed range (e.g. SoftMax, Relu6, Sigmoid).
+ void ProvideQuantizationRange(nvinfer1::ITensor* tensor, float min_range,
+ float max_range);
+
+ // Should be called when full TRT network has been constructed and before
+ // building the engine.
+ void MaybeApplyQuantizationRanges();
// Below are helper methods for op converters to add different layers to the
// TRT network.
@@ -440,6 +462,13 @@
const nvinfer1::Dims& dims,
const nvinfer1::ITensor** tensor);
+ // Return OK if the broadcast scheme is supported and compute the shapes after
+ // broadcasting.
+ Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
+ const TRT_TensorOrWeights& operand_r,
+ nvinfer1::Dims* operand_l_new_dims,
+ nvinfer1::Dims* operand_r_new_dims) const;
+
private:
// Verify the provided batch_size is consistent with batch_size_ and update it
// if necessary.
@@ -457,6 +486,12 @@
void RegisterOpConverters();
+ void PropagateQuantizationRanges();
+
+ // Gets the min and max value in a TRT_ShapedWeights
+ Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min,
+ float* out_max) const;
+
// Registered op converters by op type.
std::unordered_map<string, OpConverter> op_registry_;
@@ -472,7 +507,25 @@
// Store the weights added during construction of trt_network_.
TrtWeightStore weight_store_;
- const bool is_fp16_;
+ // During conversion, this table is populated with quantization ranges per
+ // tensor. MaybeApplyQuantizationRanges() will use this table to set the TRT
+ // quantization ranges. Since TRT only supports symmetric ranges, we will
+ // store the range as a single float = max(abs(min_range), abs(max_range)).
+ // Range refers to the floating point values, e.g. min_range = 0.0f, max_range
+ // = 6.0f for Relu6.
+ std::unordered_map<nvinfer1::ITensor*, float> quantization_ranges_;
+
+ // Edges where quantization ranges can be inferred (copied) across ops - from
+ // first tensor to second tensor. PropagateQuantizationRanges() will propagate
+ // known ranges from quantization_ranges_ across these edges, adding the new
+ // ranges to quantization_ranges_ so that they can be applied in
+ // MaybeApplyQuantizationRanges().
+ std::vector<std::pair<nvinfer1::ITensor*, nvinfer1::ITensor*>>
+ quantization_infer_;
+
+ const int precision_mode_;
+
+ const bool use_calibration_;
// Batch size of inputs to trt_network_ added by AddInputTensor(). During
// network construction it will update this, use it to verify the batch
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc
index 862754f..603c4f7 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc
@@ -35,6 +35,7 @@
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h" // NOLINT
#include "tensorflow/core/public/session.h"
@@ -49,7 +50,9 @@
namespace tensorrt {
namespace convert {
+using ::tensorflow::strings::StrCat;
using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
// TODO(laigd): put this into some test utils file.
void ExpectStatus(Status status, error::Code code = error::OK,
@@ -71,6 +74,32 @@
return dims;
}
+nvinfer1::DataType TfDataTypeToTrt(DataType tf_dtype) {
+ switch (tf_dtype) {
+ case DT_FLOAT:
+ return nvinfer1::DataType::kFLOAT;
+ case DT_HALF:
+ return nvinfer1::DataType::kHALF;
+ case DT_INT32:
+ return nvinfer1::DataType::kINT32;
+ default:
+ QCHECK(false) << "Unexpected data type " << DataTypeString(tf_dtype);
+ }
+}
+
+DataType TrtDataTypeToTf(nvinfer1::DataType trt_dtype) {
+ switch (trt_dtype) {
+ case nvinfer1::DataType::kFLOAT:
+ return DT_FLOAT;
+ case nvinfer1::DataType::kHALF:
+ return DT_HALF;
+ case nvinfer1::DataType::kINT32:
+ return DT_INT32;
+ default:
+ QCHECK(false) << "Unexpected data type " << static_cast<int>(trt_dtype);
+ }
+}
+
NodeDef MakeNodeDef(const string& name, const string& op,
const std::vector<string>& inputs) {
NodeDef node_def;
@@ -113,6 +142,15 @@
return TrtDimsEquals(GetTestDims(lhs), rhs);
}
+// TODO(laigd): define a parameterized matcher that can compare against the
+// vector.
+void ExpectTrtDimsEqualsArray(const std::vector<int>& lhs,
+ const nvinfer1::Dims& rhs) {
+ EXPECT_TRUE(TrtDimsEqualsArray(lhs, rhs))
+ << "expected: " << DebugString(GetTestDims(lhs)) << "\n"
+ << " actual: " << DebugString(rhs);
+}
+
bool TrtShapedWeightsEquals(const TRT_ShapedWeights& lhs,
const TRT_ShapedWeights& rhs) {
return TrtDimsEquals(lhs.shape_, rhs.shape_) && lhs.type_ == rhs.type_ &&
@@ -123,8 +161,7 @@
void ValidateWeights(const TRT_ShapedWeights& weights,
const std::vector<int>& expected_dims,
const std::vector<T>& expected_value) {
- EXPECT_TRUE(TrtDimsEqualsArray(expected_dims, weights.shape_))
- << weights.DebugString();
+ ExpectTrtDimsEqualsArray(expected_dims, weights.shape_);
ASSERT_EQ(expected_value.size(), weights.count()) << weights.DebugString();
const T* actual_values = static_cast<const T*>(weights.GetValues());
for (int i = 0; i < expected_value.size(); ++i) {
@@ -135,11 +172,12 @@
// Fake ITensor implementation for testing purposes.
class FakeITensor : public nvinfer1::ITensor {
public:
- FakeITensor() {}
+ FakeITensor() : dynamic_range_(0.0f) {}
- FakeITensor(const nvinfer1::Dims& dims) : dims_(dims) {}
+ FakeITensor(const nvinfer1::Dims& dims) : dims_(dims), dynamic_range_(0.0f) {}
- FakeITensor(const std::vector<int>& dims) : dims_(GetTestDims(dims)) {}
+ FakeITensor(const std::vector<int>& dims)
+ : dims_(GetTestDims(dims)), dynamic_range_(0.0f) {}
void setName(const char* name) override { name_ = name; }
@@ -168,7 +206,12 @@
}
#if NV_TENSORRT_MAJOR >= 5
- bool setDynamicRange(float min, float max) override {}
+ bool setDynamicRange(float min, float max) override {
+ dynamic_range_ = std::max(std::abs(min), std::abs(max));
+ return true;
+ }
+
+ float getDynamicRange() const override { return dynamic_range_; }
#endif
private:
@@ -176,6 +219,7 @@
nvinfer1::Dims dims_;
nvinfer1::DataType type_;
nvinfer1::TensorLocation location_;
+ float dynamic_range_;
};
TEST(TRT_ShapedWeights_Test, Basic) {
@@ -267,9 +311,7 @@
EXPECT_EQ(1, ptr->batch_size());
}
EXPECT_EQ(&itensor, ptr->tensor());
- EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims()))
- << "- expected: " << DebugString(dims)
- << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims());
+ ExpectTrtDimsEqualsArray({1}, ptr->GetTrtDims());
}
}
}
@@ -288,9 +330,7 @@
EXPECT_EQ(false, ptr->is_weights());
EXPECT_EQ(1, ptr->batch_size());
EXPECT_NE(nullptr, ptr->tensor());
- EXPECT_TRUE(TrtDimsEqualsArray({1}, ptr->GetTrtDims()))
- << "- expected: " << DebugString(dims)
- << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims());
+ ExpectTrtDimsEqualsArray({1}, ptr->GetTrtDims());
}
}
// Test constructor with TRT_ShapedWeights argument.
@@ -307,9 +347,7 @@
nvinfer1::Dims dims;
dims.nbDims = 0;
- EXPECT_TRUE(TrtDimsEqualsArray({}, ptr->GetTrtDims()))
- << "- expected: " << DebugString(dims)
- << "\n vs\n- actual: " << DebugString(ptr->GetTrtDims());
+ ExpectTrtDimsEqualsArray({}, ptr->GetTrtDims());
}
}
}
@@ -386,9 +424,7 @@
EXPECT_EQ(true, output.is_tensor());
EXPECT_EQ(batch_size, output.batch_size());
EXPECT_NE(nullptr, output.tensor());
- EXPECT_TRUE(TrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims()))
- << "- expected: {" << non_batch_dim << "} \n vs\n"
- << "- actual: " << DebugString(output.GetTrtDims());
+ ExpectTrtDimsEqualsArray({non_batch_dim}, output.GetTrtDims());
}
}
@@ -425,7 +461,9 @@
ConverterTest() {
builder_.reset(nvinfer1::createInferBuilder(logger_));
network_.reset(builder_->createNetwork());
- converter_.reset(new Converter(network_.get(), /*fp16=*/false));
+ converter_.reset(new Converter(network_.get(),
+ /*precision_mode=*/FP32MODE,
+ /*use_calibration=*/false));
weight_store_ = &converter_->weight_store_;
}
@@ -452,8 +490,21 @@
return converter_->GetInputs(node_def, inputs);
}
+ Status GetWeightRange(const TRT_ShapedWeights& weights, float* out_min,
+ float* out_max) const {
+ return converter_->GetWeightRange(weights, out_min, out_max);
+ }
+
+ void PropagateQuantizationRanges() {
+ converter_->PropagateQuantizationRanges();
+ }
+
int batch_size() const { return converter_->batch_size_; }
+ std::unordered_map<nvinfer1::ITensor*, float>& quantization_ranges() {
+ return converter_->quantization_ranges_;
+ }
+
private:
Logger logger_;
// These members are ordered in a way such that the destruction order is:
@@ -524,9 +575,9 @@
EXPECT_EQ(nvinfer1::DataType::kFLOAT, inputs[0].tensor()->getType());
EXPECT_EQ(nvinfer1::DataType::kINT32, inputs[2].tensor()->getType());
EXPECT_EQ(nvinfer1::DataType::kHALF, inputs[3].tensor()->getType());
- EXPECT_TRUE(TrtDimsEqualsArray({1}, inputs[0].tensor()->getDimensions()));
- EXPECT_TRUE(TrtDimsEqualsArray({2, 3}, inputs[2].tensor()->getDimensions()));
- EXPECT_TRUE(TrtDimsEqualsArray({5, 3}, inputs[3].tensor()->getDimensions()));
+ ExpectTrtDimsEqualsArray({1}, inputs[0].tensor()->getDimensions());
+ ExpectTrtDimsEqualsArray({2, 3}, inputs[2].tensor()->getDimensions());
+ ExpectTrtDimsEqualsArray({5, 3}, inputs[3].tensor()->getDimensions());
}
TEST_F(ConverterTest, RenameAndMarkOutputTensors) {
@@ -572,7 +623,7 @@
{{"my_op", "my_output"}, {"my_op:1", "my_output_1"}}));
EXPECT_EQ(2, output_tensors.size());
for (auto output_tensor : output_tensors) {
- EXPECT_TRUE(TrtDimsEqualsArray({2, 1}, output_tensor->getDimensions()));
+ ExpectTrtDimsEqualsArray({2, 1}, output_tensor->getDimensions());
}
EXPECT_EQ("my_output", string(output_tensors[0]->getName()));
EXPECT_EQ("my_output_1", string(output_tensors[1]->getName()));
@@ -597,8 +648,7 @@
// OK.
TF_EXPECT_OK(
converter_->TransposeTensor(input_tensor, {0, 3, 1, 2}, &output_tensor));
- EXPECT_TRUE(TrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions()))
- << DebugString(*output_tensor);
+ ExpectTrtDimsEqualsArray({5, 2, 3}, output_tensor->getDimensions());
}
TEST_F(ConverterTest, PrepareTensorForShape_Tensor) {
@@ -610,7 +660,7 @@
// Shape size doesn't match.
ExpectStatus(converter_->PrepareTensorForShape(tw, GetTestDims({2, 3, 6}),
&output_tensor),
- error::INVALID_ARGUMENT, "Reshape shapes are not compatible.");
+ error::INVALID_ARGUMENT, "Reshape shapes are not compatible");
// TODO(aaroey): we should check the case where uninferred dimensions are not
// an exact divisor of input dim ensions, e.g. for dims {-1, 7}.
@@ -618,14 +668,12 @@
// Infer shape, ok.
TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({-1, 2}),
&output_tensor));
- EXPECT_TRUE(TrtDimsEqualsArray({15, 2}, output_tensor->getDimensions()))
- << DebugString(*output_tensor);
+ ExpectTrtDimsEqualsArray({15, 2}, output_tensor->getDimensions());
// Regular shape.
TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({10, 3}),
&output_tensor));
- EXPECT_TRUE(TrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()))
- << DebugString(*output_tensor);
+ ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions());
}
TEST_F(ConverterTest, PrepareTensorForShape_Weights) {
@@ -635,8 +683,7 @@
const nvinfer1::ITensor* output_tensor = nullptr;
TF_EXPECT_OK(converter_->PrepareTensorForShape(tw, GetTestDims({10, 3}),
&output_tensor));
- EXPECT_TRUE(TrtDimsEqualsArray({10, 3}, output_tensor->getDimensions()))
- << DebugString(*output_tensor);
+ ExpectTrtDimsEqualsArray({10, 3}, output_tensor->getDimensions());
}
TEST_F(ConverterTest, MaybeUpdateBatchSize) {
@@ -676,6 +723,178 @@
"tensor/weights my_tensor already exist");
}
+template <typename T>
+void TestGetWeightRange(ConverterTest* test, TrtWeightStore* weight_store) {
+ TRT_ShapedWeights weights =
+ weight_store->GetTempWeights(DataTypeToEnum<T>::v(), GetTestDims({2, 3}));
+ const std::vector<T> values = {T(3), T(1), T(2), T(6), T(5), T(4)};
+ memcpy(const_cast<void*>(weights.GetValues()), values.data(),
+ weights.size_bytes());
+
+ float out_min = 0.0f;
+ float out_max = 0.0f;
+ TF_EXPECT_OK(test->GetWeightRange(weights, &out_min, &out_max));
+ EXPECT_EQ(1.0f, out_min);
+ EXPECT_EQ(6.0f, out_max);
+}
+
+TEST_F(ConverterTest, GetWeightRange) {
+ TestGetWeightRange<float>(this, weight_store_);
+ TestGetWeightRange<Eigen::half>(this, weight_store_);
+ TestGetWeightRange<int32>(this, weight_store_);
+}
+
+TEST_F(ConverterTest, ProvideQuantizationRange) {
+ FakeITensor fake_tensor;
+ // Assymetric range
+ converter_->ProvideQuantizationRange(&fake_tensor, 0.0f, 6.0f);
+ EXPECT_EQ(6.0f, quantization_ranges()[&fake_tensor]);
+ converter_->ProvideQuantizationRange(&fake_tensor, 1.0f, 6.0f);
+ EXPECT_EQ(6.0f, quantization_ranges()[&fake_tensor]);
+ converter_->ProvideQuantizationRange(&fake_tensor, -8.0f, 6.0f);
+ EXPECT_EQ(8.0f, quantization_ranges()[&fake_tensor]);
+ converter_->ProvideQuantizationRange(&fake_tensor, -8.123f, -6.123f);
+ EXPECT_EQ(8.123f, quantization_ranges()[&fake_tensor]);
+ // Symmetric range
+ converter_->ProvideQuantizationRange(&fake_tensor, -6.123f, 6.123f);
+ EXPECT_EQ(6.123f, quantization_ranges()[&fake_tensor]);
+}
+
+TEST_F(ConverterTest, MaybeApplyQuantizationRanges) {
+ // input -> infer1 -> infer2 -> infer3
+ FakeITensor input, infer_1, infer_2, infer_3;
+ FakeITensor not_infer;
+ Converter int8_converter(/*trt_network=*/nullptr, INT8MODE,
+ /*use_calibration=*/true);
+ int8_converter.ProvideQuantizationRange(&input, -5.0f, 5.0f);
+ int8_converter.ProvideQuantizationRange(¬_infer, -100.0f, 100.0f);
+ int8_converter.MarkQuantizationRangesAsInferrable(&input, &infer_1);
+ int8_converter.MarkQuantizationRangesAsInferrable(&infer_1, &infer_2);
+ int8_converter.MarkQuantizationRangesAsInferrable(&infer_2, &infer_3);
+
+ // Input range should be inferred along the chain and applied to tensors.
+ int8_converter.MaybeApplyQuantizationRanges();
+#if NV_TENSORRT_MAJOR >= 5
+ EXPECT_EQ(input.getDynamicRange(), 5.0f);
+ EXPECT_EQ(infer_1.getDynamicRange(), 5.0f);
+ EXPECT_EQ(infer_2.getDynamicRange(), 5.0f);
+ EXPECT_EQ(infer_3.getDynamicRange(), 5.0f);
+ EXPECT_EQ(not_infer.getDynamicRange(), 100.0f);
+#endif
+}
+
+TEST_F(ConverterTest, PropagateQuantizationRanges) {
+ // infer0 <-> infer1 <-> infer2 <-> infer3
+ // |
+ // infer4 <-> infer5
+ FakeITensor infer[6];
+ FakeITensor not_infer;
+ converter_->ProvideQuantizationRange(&infer[4], -5.0f, 5.0f);
+ converter_->MarkQuantizationRangesAsInferrable(&infer[0], &infer[1]);
+ converter_->MarkQuantizationRangesAsInferrable(&infer[1], &infer[2]);
+ converter_->MarkQuantizationRangesAsInferrable(&infer[3], &infer[2]);
+ converter_->MarkQuantizationRangesAsInferrable(&infer[4], &infer[1]);
+ converter_->MarkQuantizationRangesAsInferrable(&infer[4], &infer[5]);
+
+ // Input range should be inferred along the chain.
+ PropagateQuantizationRanges();
+ auto ranges = quantization_ranges();
+ for (int i = 0; i < 6; ++i) {
+ EXPECT_EQ(5.0f, ranges[&infer[i]]);
+ }
+ EXPECT_EQ(ranges.count(¬_infer), 0);
+}
+
+TEST_F(ConverterTest, GetTrtBroadcastShape) {
+ const bool kIsTensor = true;
+ const bool kIsNotTensor = false;
+ auto symmetric_test = [this](const std::vector<int>& operand_1_shape,
+ const std::vector<int>& operand_2_shape,
+ const bool operand_1_is_tensor,
+ const bool operand_2_is_tensor,
+ const std::vector<int>& expected_operand_1_shape,
+ const std::vector<int>& expected_operand_2_shape,
+ error::Code expected_code = error::OK,
+ const char* expected_error_msg_substr = nullptr,
+ const int operand_1_batch_size = -1,
+ const int operand_2_batch_size = -1) {
+ auto create_tensor_or_weights = [](const std::vector<int>& shape,
+ bool is_tensor, int batch_size = -1) {
+ if (is_tensor) {
+ return TRT_TensorOrWeights{nvinfer1::DataType::kFLOAT,
+ GetTestDims(shape), batch_size};
+ }
+ TRT_ShapedWeights weights;
+ weights.shape_ = GetTestDims(shape);
+ return TRT_TensorOrWeights(weights);
+ };
+
+ nvinfer1::Dims operand_1_new_dims, operand_2_new_dims;
+ TRT_TensorOrWeights operand_1 = create_tensor_or_weights(
+ operand_1_shape, operand_1_is_tensor, operand_1_batch_size);
+ TRT_TensorOrWeights operand_2 = create_tensor_or_weights(
+ operand_2_shape, operand_2_is_tensor, operand_2_batch_size);
+
+ // operand_1 broadcast operand_2
+ ExpectStatus(
+ this->converter_->GetTrtBroadcastShape(
+ operand_1, operand_2, &operand_1_new_dims, &operand_2_new_dims),
+ expected_code, expected_error_msg_substr);
+ if (expected_code == error::OK) {
+ ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
+ ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims);
+ }
+ // operand_2 broadcast operand_1
+ ExpectStatus(
+ this->converter_->GetTrtBroadcastShape(
+ operand_2, operand_1, &operand_2_new_dims, &operand_1_new_dims),
+ expected_code, expected_error_msg_substr);
+ if (expected_code == error::OK) {
+ ExpectTrtDimsEqualsArray(expected_operand_1_shape, operand_1_new_dims);
+ ExpectTrtDimsEqualsArray(expected_operand_2_shape, operand_2_new_dims);
+ }
+ };
+
+ // Both inputs are weights.
+ symmetric_test(
+ {1}, {1}, kIsNotTensor, kIsNotTensor, {}, {}, error::INVALID_ARGUMENT,
+ "Broadcasting requires at least one of the operands be tensors");
+
+ // One tensor and one weights.
+ symmetric_test({1, 1, 1}, {2}, kIsTensor, kIsNotTensor, {1, 1, 1}, {1, 1, 2});
+ symmetric_test({1, 1, 2}, {2}, kIsTensor, kIsNotTensor, {1, 1, 2}, {1, 1, 2});
+ symmetric_test({1, 3, 2}, {1}, kIsTensor, kIsNotTensor, {1, 3, 2}, {1, 1, 1});
+ symmetric_test({1, 1, 1}, {2, 3}, kIsTensor, kIsNotTensor, {1, 1, 1},
+ {1, 2, 3});
+ symmetric_test({1, 1, 1}, {2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1},
+ {2, 3, 4});
+ symmetric_test({1, 1, 1}, {1, 2, 3, 4}, kIsTensor, kIsNotTensor, {1, 1, 1},
+ {2, 3, 4});
+ symmetric_test({1, 3, 4}, {1, 2, 1, 4}, kIsTensor, kIsNotTensor, {1, 3, 4},
+ {2, 1, 4});
+ symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {},
+ error::INVALID_ARGUMENT, "Infeasible broadcast scheme");
+ symmetric_test({1, 1, 1}, {2, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {},
+ error::INVALID_ARGUMENT, "Infeasible broadcast scheme",
+ /*operand_1_batch_size=*/2);
+ symmetric_test({1, 1, 1}, {1, 1, 1, 1, 1}, kIsTensor, kIsNotTensor, {}, {},
+ error::INVALID_ARGUMENT,
+ "Broadcasting beyond batch dimension is not supported "
+ "(tensor #dims 4 vs broadcast #dims 5)");
+
+ // Both inputs are tensors.
+ symmetric_test({1, 1, 1}, {1, 1}, kIsTensor, kIsTensor, {}, {},
+ error::INVALID_ARGUMENT,
+ "Broadcasting beyond batch dimension is not supported "
+ "(tensor #dims 3 vs broadcast #dims 4)");
+ symmetric_test({1, 3, 4}, {2, 1, 4}, kIsTensor, kIsTensor, {1, 3, 4},
+ {2, 1, 4});
+ symmetric_test({1, 1, 1}, {1, 1, 1, 1}, kIsTensor, kIsTensor, {}, {},
+ error::INVALID_ARGUMENT,
+ "Broadcasting beyond batch dimension is not supported "
+ "(tensor #dims 4 vs broadcast #dims 5)");
+}
+
// Class to test various op converters, using both a TrtNodeValidator and
// Converter.
class OpConverterTest : public ::testing::Test {
@@ -704,7 +923,9 @@
// Reset the validator and converter.
validator_.reset(new TrtNodeValidator);
- converter_.reset(new Converter(network_.get(), /*fp16=*/false));
+ converter_.reset(new Converter(network_.get(),
+ /*precision_mode=*/FP32MODE,
+ /*use_calibration=*/false));
// Reset other related artifacts.
scope_ = Scope::NewRootScope();
@@ -712,8 +933,11 @@
}
// TODO(laigd): test fp16 and int8 support.
- void BuildAndRun(const char* input_name, const std::vector<float>& input_data,
- const char* output_name, std::vector<float>* output_data) {
+ template <typename T>
+ void BuildAndRun(
+ const std::vector<std::pair<const char*, const std::vector<T>>>&
+ input_data,
+ const char* output_name, std::vector<T>* output_data) {
// Mark the output tensor as TRT engine output.
TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(
{{string(output_name), string(output_name)}}));
@@ -724,25 +948,33 @@
CHECK_NOTNULL(engine_.get());
// Execute the TRT engine.
- const int input_size = input_data.size() * sizeof(float);
- const int output_size = output_data->size() * sizeof(float);
- const int input_index = engine_->getBindingIndex(input_name);
- const int output_index = engine_->getBindingIndex(output_name);
+ ASSERT_LE(input_data.size() + 1, 3);
+ void* buffers[3];
+ for (const auto name_and_data : input_data) {
+ const int input_size = name_and_data.second.size() * sizeof(T);
+ const int input_index = engine_->getBindingIndex(name_and_data.first);
+ ASSERT_EQ(0, cudaMalloc(&buffers[input_index], input_size));
+ ASSERT_EQ(
+ 0, cudaMemcpyAsync(buffers[input_index], name_and_data.second.data(),
+ input_size, cudaMemcpyHostToDevice, stream_));
+ }
- ASSERT_EQ(engine_->getNbBindings(), 2);
- void* buffers[2];
- ASSERT_EQ(0, cudaMalloc(&buffers[input_index], input_size));
+ const int output_size = output_data->size() * sizeof(T);
+ const int output_index = engine_->getBindingIndex(output_name);
ASSERT_EQ(0, cudaMalloc(&buffers[output_index], output_size));
- ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input_data.data(),
- input_size, cudaMemcpyHostToDevice, stream_));
+
+ ASSERT_EQ(engine_->getNbBindings(), input_data.size() + 1);
+
TrtUniquePtrType<nvinfer1::IExecutionContext> execution_context(
engine_->createExecutionContext());
execution_context->enqueue(/*batchSize=*/1, buffers, stream_, nullptr);
ASSERT_EQ(0, cudaMemcpyAsync(output_data->data(), buffers[output_index],
output_size, cudaMemcpyDeviceToHost, stream_));
cudaStreamSynchronize(stream_);
- ASSERT_EQ(0, cudaFree(buffers[input_index]));
- ASSERT_EQ(0, cudaFree(buffers[output_index]));
+
+ for (int i = 0; i < input_data.size() + 1; ++i) {
+ ASSERT_EQ(0, cudaFree(buffers[i]));
+ }
}
bool HasStaticShape(const nvinfer1::Dims& dims) const {
@@ -757,18 +989,7 @@
void AddTestTensor(
const char* name, const std::vector<int32>& dims, int batch_size = 1,
nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT) {
- DataType tf_dtype = DT_FLOAT;
- switch (trt_dtype) {
- case nvinfer1::DataType::kFLOAT:
- tf_dtype = DT_FLOAT;
- break;
- case nvinfer1::DataType::kINT32:
- tf_dtype = DT_INT32;
- break;
- default:
- ASSERT_TRUE(false) << "Unexpected data type "
- << static_cast<int>(trt_dtype);
- }
+ DataType tf_dtype = TrtDataTypeToTf(trt_dtype);
ops::Placeholder::Attrs attrs;
TF_EXPECT_OK(TensorShapeUtils::MakeShape(dims, &attrs.shape_));
attrs.shape_.InsertDim(0, batch_size);
@@ -847,6 +1068,11 @@
}
}
+ // Expose quantization_ranges_ for tests
+ std::unordered_map<nvinfer1::ITensor*, float>& quantization_ranges() {
+ return converter_->quantization_ranges_;
+ }
+
std::unique_ptr<Converter> converter_;
std::unique_ptr<TrtNodeValidator> validator_;
@@ -856,6 +1082,11 @@
TrtUniquePtrType<nvinfer1::INetworkDefinition> network_;
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
cudaStream_t stream_;
+ // Used to create placeholders with shape and data type information. The
+ // created placeholders will be used as inputs to the node to be verified,
+ // thus we need the shape and data type information to get a non-empty
+ // GraphProperties.
+ // TODO(laigd): consider use this Scope to create the NodeDef to verify.
Scope scope_;
std::unordered_map<string, NodeDef> validator_inputs_;
};
@@ -979,15 +1210,15 @@
Reset();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("weights", {4}, {0, 3, 1, 2});
- RunConversion(node_def);
+ RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_transpose", &output));
EXPECT_TRUE(output.is_tensor());
- EXPECT_TRUE(TrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions()))
- << output.DebugString();
+ ExpectTrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions());
std::vector<float> output_data(6);
- BuildAndRun("input", {1, 2, 3, 4, 5, 6}, "my_transpose", &output_data);
+ BuildAndRun<float>({{"input", {1, 2, 3, 4, 5, 6}}}, "my_transpose",
+ &output_data);
EXPECT_THAT(output_data, ElementsAre(1, 4, 2, 5, 3, 6));
}
}
@@ -1069,15 +1300,15 @@
Reset();
AddTestTensor("input", ok_params[i].tensor_dims, ok_params[i].batch_size);
AddTestWeights<int32>("weights", {4}, ok_params[i].shape);
- RunConversion(node_def);
+ RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_reshape", &output));
EXPECT_TRUE(output.is_tensor());
- EXPECT_TRUE(TrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions()))
- << output.DebugString();
+ ExpectTrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions());
std::vector<float> output_data(6);
- BuildAndRun("input", {1, 2, 3, 4, 5, 6}, "my_reshape", &output_data);
+ BuildAndRun<float>({{"input", {1, 2, 3, 4, 5, 6}}}, "my_reshape",
+ &output_data);
EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6));
}
}
@@ -1132,15 +1363,14 @@
get_matmul_nodedef(DT_FLOAT, /*transpose_a=*/false, transpose_b);
AddTestTensor("input", {2}, /*batch_size=*/1);
AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
- RunConversion(node_def);
+ RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output));
EXPECT_TRUE(output.is_tensor());
- EXPECT_TRUE(TrtDimsEqualsArray({2}, output.tensor()->getDimensions()))
- << output.DebugString();
+ ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
std::vector<float> output_data(2);
- BuildAndRun("input", {0, 1}, "my_matmul", &output_data);
+ BuildAndRun<float>({{"input", {0, 1}}}, "my_matmul", &output_data);
if (transpose_b) {
EXPECT_THAT(output_data, ElementsAre(1, 3));
} else {
@@ -1176,12 +1406,15 @@
dims_array[0] = 2;
dims_array[trt_input_rank - 1] = 3;
}
- test->AddTestTensor("input", dims_array, /*batch_size=*/1);
+ test->AddTestTensor("input", dims_array, /*batch_size=*/1,
+ TfDataTypeToTrt(dtype));
// Add bias weights.
const int channel_size = (data_format == "NHWC" ? 3 : 2);
std::vector<CType> bias(channel_size);
- std::iota(bias.begin(), bias.end(), 1); // bias will be {1, 2, 3, ...}
+ for (int i = 0; i < channel_size; ++i) {
+ bias[i] = CType(i + 1); // bias will be {1, 2, 3, ...}
+ }
test->AddTestWeights<CType>("weights", {channel_size}, bias);
// Run the conversion.
@@ -1189,28 +1422,29 @@
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_biasadd", &output));
EXPECT_TRUE(output.is_tensor());
- EXPECT_TRUE(
- TrtDimsEqualsArray(dims_array, output.tensor()->getDimensions()))
- << output.DebugString();
+ ExpectTrtDimsEqualsArray(dims_array, output.tensor()->getDimensions());
// Build and run the engine.
const int num_input = TrtDimsNumElements(GetTestDims(dims_array));
ASSERT_EQ(trt_input_rank > 1 ? 6 : (data_format == "NHWC" ? 3 : 2),
num_input);
std::vector<CType> output_data(num_input);
- test->BuildAndRun("input", std::vector<CType>(num_input, CType(0)),
- "my_biasadd", &output_data);
+ test->BuildAndRun<CType>(
+ {{"input", std::vector<CType>(num_input, CType(0))}}, "my_biasadd",
+ &output_data);
if (trt_input_rank == 1) {
if (data_format == "NHWC") {
- EXPECT_THAT(output_data, ElementsAre(1, 2, 3));
+ EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3)));
} else {
- EXPECT_THAT(output_data, ElementsAre(1, 2));
+ EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2)));
}
} else {
if (data_format == "NHWC") {
- EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 1, 2, 3));
+ EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3),
+ CType(1), CType(2), CType(3)));
} else {
- EXPECT_THAT(output_data, ElementsAre(1, 1, 1, 2, 2, 2));
+ EXPECT_THAT(output_data, ElementsAre(CType(1), CType(1), CType(1),
+ CType(2), CType(2), CType(2)));
}
}
}
@@ -1226,11 +1460,508 @@
"Input expects tensor and weights, at my_biasadd");
}
- // OK.
+ // OK. Note that kINT32 is not supported by IScaleLayer, so we don't test
+ // DT_INT32 type here.
TestConvertBiasAdd<DT_FLOAT>(this);
- // TODO(laigd): uncomment this after cl/220663893 is submitted.
- // TestConvertBiasAdd<DT_INT32>(this);
- // TestConvertBiasAdd<DT_HALF>(this);
+ TestConvertBiasAdd<DT_HALF>(this);
+}
+
+template <typename OpType>
+NodeDef GetBinaryOpNodeDef(const string& input_name_l,
+ const string& input_name_r, DataType dtype) {
+ Scope s = Scope::NewRootScope();
+ auto input_l = ops::Placeholder(s.WithOpName(input_name_l), dtype);
+ auto input_r = ops::Placeholder(s.WithOpName(input_name_r), dtype);
+ auto op = OpType(s.WithOpName("my_binary"), input_l, input_r);
+ return op.operation.node()->def();
+}
+
+void CheckAddedLayers(OpConverterTest* test, bool expect_scale_layer) {
+ bool element_wise_layer_found = false;
+ bool scale_layer_found = false;
+ for (int i = 0; i < test->converter_->network()->getNbLayers(); i++) {
+ nvinfer1::ILayer* layer = test->converter_->network()->getLayer(i);
+ if (dynamic_cast<nvinfer1::IScaleLayer*>(layer)) {
+ scale_layer_found = true;
+ } else if (dynamic_cast<nvinfer1::IElementWiseLayer*>(layer)) {
+ element_wise_layer_found = true;
+ }
+ }
+ EXPECT_EQ(expect_scale_layer, scale_layer_found);
+ EXPECT_NE(expect_scale_layer, element_wise_layer_found);
+}
+
+template <typename OpType, DataType dtype>
+void TestBinaryTensorOpWeightNoBroadcast(OpConverterTest* test) {
+ typedef typename EnumToDataType<dtype>::Type CType;
+ for (auto swap_inputs : {false, true}) {
+ test->Reset();
+ NodeDef node_def;
+ if (swap_inputs) {
+ node_def = GetBinaryOpNodeDef<OpType>("weights", "input", dtype);
+ } else {
+ node_def = GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
+ }
+
+ const std::vector<CType> operand1{CType(3), CType(7.5)};
+ const std::vector<CType> operand2{CType(2), CType(3)};
+
+ // It requires the dims to be at least of rank 3 to apply an IScaleLayer.
+ test->AddTestTensor("input", /*dims=*/{1, 1, 2}, /*batch_size=*/1,
+ TfDataTypeToTrt(dtype));
+ test->AddTestWeights<CType>("weights", /*dims=*/{1, 1, 2},
+ /*values=*/swap_inputs ? operand1 : operand2);
+ test->RunValidationAndConversion(node_def);
+
+ // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
+ CheckAddedLayers(test, /*expect_scale_layer=*/true);
+
+ // Check the dims of the output ITensor.
+ TRT_TensorOrWeights output;
+ TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
+ EXPECT_TRUE(output.is_tensor());
+ ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions());
+
+ std::vector<CType> output_data(2);
+ test->BuildAndRun<CType>(
+ {{"input",
+ /*input_data=*/swap_inputs ? operand2 : operand1}},
+ "my_binary", &output_data);
+ if (node_def.op() == "Add") {
+ EXPECT_THAT(output_data, ElementsAre(CType(5), CType(10.5)));
+ } else if (node_def.op() == "Sub") {
+ EXPECT_THAT(output_data, ElementsAre(CType(1), CType(4.5)));
+ } else if (node_def.op() == "Mul") {
+ EXPECT_THAT(output_data, ElementsAre(CType(6), CType(22.5)));
+ } else if (node_def.op() == "Div") {
+ EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5)));
+ } else if (node_def.op() == "RealDiv") {
+ EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5)));
+ } else {
+ ASSERT_TRUE(false);
+ }
+ }
+}
+
+template <DataType dtype>
+void TestBinaryTensorOpWeightWithChannelWiseBroadcast(OpConverterTest* test) {
+ typedef typename EnumToDataType<dtype>::Type CType;
+ const NodeDef node_def =
+ GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
+ const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
+ const std::vector<CType> weights{CType(10), CType(20)};
+ // There are two types of valid dim pairs which requires channel-wise
+ // broadcasting:
+ // - input dims (X Y Z) vs weights dims (X 1 1)
+ // - input dims (X Y Z) vs weights dims (Z)
+ // Here X=Z=2 and Y=1.
+ for (auto weights_dims : std::vector<std::vector<int>>{{2, 1, 1}, {2}}) {
+ test->Reset();
+ test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
+ TfDataTypeToTrt(dtype));
+ test->AddTestWeights<CType>("weights", weights_dims, weights);
+ test->RunValidationAndConversion(node_def);
+
+ // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
+ CheckAddedLayers(test, /*expect_scale_layer=*/true);
+
+ // Check the dims of the output ITensor.
+ TRT_TensorOrWeights output;
+ TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
+ EXPECT_TRUE(output.is_tensor());
+ ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
+
+ std::vector<CType> output_data(4);
+ test->BuildAndRun<CType>({{"input", input}}, "my_binary", &output_data);
+ if (weights_dims.size() == 1) {
+ EXPECT_THAT(output_data,
+ ElementsAre(CType(11), CType(22), CType(13), CType(24)));
+ } else {
+ EXPECT_THAT(output_data,
+ ElementsAre(CType(11), CType(12), CType(23), CType(24)));
+ }
+ }
+}
+
+template <DataType dtype>
+void TestBinaryTensorOpWeightWithUniformlyBroadcast(OpConverterTest* test) {
+ typedef typename EnumToDataType<dtype>::Type CType;
+ const NodeDef node_def =
+ GetBinaryOpNodeDef<ops::Add>("input", "weights", dtype);
+ const std::vector<CType> input{CType(1), CType(2), CType(3), CType(4)};
+ const std::vector<CType> weights{CType(10)};
+ test->Reset();
+ test->AddTestTensor("input", /*dims=*/{2, 1, 2}, /*batch_size=*/1,
+ TfDataTypeToTrt(dtype));
+ test->AddTestWeights<CType>("weights", {1, 1, 1, 1}, weights);
+ test->RunValidationAndConversion(node_def);
+
+ // Make sure it does use BinaryTensorOpWeight, not BinaryTensorOpTensor.
+ CheckAddedLayers(test, /*expect_scale_layer=*/true);
+
+ // Check the dims of the output ITensor.
+ TRT_TensorOrWeights output;
+ TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
+ EXPECT_TRUE(output.is_tensor());
+ ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
+
+ std::vector<CType> output_data(4);
+ test->BuildAndRun<CType>({{"input", input}}, "my_binary", &output_data);
+ EXPECT_THAT(output_data,
+ ElementsAre(CType(11), CType(12), CType(13), CType(14)));
+}
+
+template <typename OpType>
+void TestBinaryTensorOpWeightFallback(OpConverterTest* test,
+ const std::vector<int32>& input_dims,
+ const std::vector<int>& weights_dims,
+ error::Code code = error::OK,
+ const char* error_msg_substr = nullptr,
+ const int input_batch_size = 1) {
+ const DataType dtype = DT_FLOAT;
+ typedef typename EnumToDataType<dtype>::Type CType;
+ const size_t num_inputs = TrtDimsNumElements(GetTestDims(input_dims));
+ const size_t num_weights = TrtDimsNumElements(GetTestDims(weights_dims));
+
+ test->Reset();
+ const NodeDef node_def =
+ GetBinaryOpNodeDef<OpType>("input", "weights", dtype);
+ test->AddTestTensor("input", /*dims=*/input_dims, input_batch_size,
+ TfDataTypeToTrt(dtype));
+ test->AddTestWeights<CType>(
+ "weights", /*dims=*/weights_dims,
+ /*values=*/std::vector<CType>(num_weights, CType(1)));
+ test->RunValidationAndConversion(node_def, code, error_msg_substr);
+ if (code != error::OK) return;
+
+ // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight.
+ CheckAddedLayers(test, /*expect_scale_layer=*/false);
+
+ TRT_TensorOrWeights output;
+ TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
+ EXPECT_TRUE(output.is_tensor());
+
+ // Check the dims of the output ITensor.
+ std::vector<int> expected_output_dims = input_dims;
+ for (int i = expected_output_dims.size() - 1, j = weights_dims.size() - 1;
+ i >= 0 && j >= 0; --i, --j) {
+ if (expected_output_dims[i] == 1) {
+ expected_output_dims[i] = weights_dims[j];
+ }
+ }
+ ExpectTrtDimsEqualsArray(expected_output_dims,
+ output.tensor()->getDimensions());
+
+ // Check the result of running the engine.
+ const int expected_num_outputs =
+ TrtDimsNumElements(GetTestDims(expected_output_dims));
+ std::vector<CType> output_data(expected_num_outputs);
+ test->BuildAndRun<CType>(
+ {{"input",
+ /*input_data=*/std::vector<CType>(num_inputs, CType(2))}},
+ "my_binary", &output_data);
+ if (node_def.op() == "Add") {
+ EXPECT_THAT(output_data, ElementsAreArray(std::vector<CType>(
+ expected_num_outputs, CType(3))));
+ } else if (node_def.op() == "Minimum") {
+ EXPECT_THAT(output_data, ElementsAreArray(std::vector<CType>(
+ expected_num_outputs, CType(1))));
+ } else {
+ ASSERT_TRUE(false);
+ }
+}
+
+template <typename OpType, DataType dtype>
+void TestBinaryTensorOpTensor(OpConverterTest* test) {
+ typedef typename EnumToDataType<dtype>::Type CType;
+ test->Reset();
+ const NodeDef node_def =
+ GetBinaryOpNodeDef<OpType>("input1", "input2", dtype);
+ test->AddTestTensor("input1", /*dims=*/{1, 2}, /*batch_size=*/1,
+ TfDataTypeToTrt(dtype));
+ test->AddTestTensor("input2", /*dims=*/{2, 1}, /*batch_size=*/1,
+ TfDataTypeToTrt(dtype));
+ test->RunValidationAndConversion(node_def);
+
+ // Make sure it does use BinaryTensorOpTensor, not BinaryTensorOpWeight.
+ CheckAddedLayers(test, /*expect_scale_layer=*/false);
+
+ // Check output dims.
+ TRT_TensorOrWeights output;
+ TF_EXPECT_OK(test->GetTensorOrWeights("my_binary", &output));
+ EXPECT_TRUE(output.is_tensor());
+ ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions());
+
+ std::vector<CType> output_data(4);
+ // After broadcasting first input becomes {3, 6, 3, 6} and second input
+ // becomes {2, 3, 2, 3}.
+ test->BuildAndRun<CType>(
+ {{"input1", {CType(3), CType(6)}}, {"input2", {CType(2), CType(3)}}},
+ "my_binary", &output_data);
+ if (node_def.op() == "Add") {
+ EXPECT_THAT(output_data,
+ ElementsAre(CType(5), CType(8), CType(6), CType(9)));
+ } else if (node_def.op() == "Sub") {
+ EXPECT_THAT(output_data,
+ ElementsAre(CType(1), CType(4), CType(0), CType(3)));
+ } else if (node_def.op() == "Mul") {
+ EXPECT_THAT(output_data,
+ ElementsAre(CType(6), CType(12), CType(9), CType(18)));
+ } else if (node_def.op() == "Div") {
+ EXPECT_THAT(output_data,
+ ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
+ } else if (node_def.op() == "RealDiv") {
+ EXPECT_THAT(output_data,
+ ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
+ } else if (node_def.op() == "Minimum") {
+ EXPECT_THAT(output_data,
+ ElementsAre(CType(2), CType(2), CType(3), CType(3)));
+ } else if (node_def.op() == "Maximum") {
+ EXPECT_THAT(output_data,
+ ElementsAre(CType(3), CType(6), CType(3), CType(6)));
+ } else {
+ ASSERT_TRUE(false);
+ }
+}
+
+TEST_F(OpConverterTest, ConvertBinary) {
+ // Input size doesn't match, should fail.
+ for (size_t num_inputs = 0; num_inputs < 2; ++num_inputs) {
+ Reset();
+ NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"});
+ AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT);
+ RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
+ "Binary ops require two inputs, at my_add");
+ }
+ {
+ // Both inputs are weights.
+ Reset();
+ NodeDef node_def = MakeNodeDef("my_add", "Add", {"weights1", "weights2"});
+ AddTestWeights<float>("weights1", {1}, {1});
+ AddTestWeights<float>("weights2", {1}, {1});
+ RunValidationAndConversion(
+ node_def, error::UNIMPLEMENTED,
+ "Constant folding is falled back to TensorFlow, binary op received "
+ "both input as constant at: my_add");
+ }
+
+ // Test BinaryTensorOpWeight() without broadcasting.
+ TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_FLOAT>(this);
+ TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_FLOAT>(this);
+ TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_FLOAT>(this);
+ TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_FLOAT>(this);
+ TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_FLOAT>(this);
+#if 0
+ // TODO(b/119560144): it doesn't support FP16 constants and the following test
+ // will fail.
+ TestBinaryTensorOpWeightNoBroadcast<ops::Add, DT_HALF>(this);
+ TestBinaryTensorOpWeightNoBroadcast<ops::Sub, DT_HALF>(this);
+ TestBinaryTensorOpWeightNoBroadcast<ops::Mul, DT_HALF>(this);
+ TestBinaryTensorOpWeightNoBroadcast<ops::Div, DT_HALF>(this);
+ TestBinaryTensorOpWeightNoBroadcast<ops::RealDiv, DT_HALF>(this);
+#endif
+
+ // Test BinaryTensorOpWeight() with channel-wise broadcasting.
+ TestBinaryTensorOpWeightWithChannelWiseBroadcast<DT_FLOAT>(this);
+
+ // Test BinaryTensorOpWeight() with uniformly broadcasting.
+ TestBinaryTensorOpWeightWithUniformlyBroadcast<DT_FLOAT>(this);
+
+ // Test BinaryTensorOpWeight() falling back to BinaryTensorOpTensor().
+ // Unsupported op.
+ TestBinaryTensorOpWeightFallback<ops::Minimum>(this, {1, 1, 1}, {1});
+ // Rank of input tensor dimension <3.
+ TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1}, {1});
+ // Broadcast on batch dimension, should fail.
+ TestBinaryTensorOpWeightFallback<ops::Add>(
+ this, {1, 1, 1}, {2, 1, 1, 1}, error::INVALID_ARGUMENT,
+ "Unsupported binary op broadcast scheme for op my_binary",
+ /*input_batch_size=*/2);
+ // Incompatible dims with per-channel mode.
+ TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 1, 1}, {1, 2, 1});
+ // Incompatible dims.
+ TestBinaryTensorOpWeightFallback<ops::Add>(this, {1, 2, 1}, {2});
+
+ // Test BinaryTensorOpTensor() with broadcasting.
+ TestBinaryTensorOpTensor<ops::Add, DT_FLOAT>(this);
+ TestBinaryTensorOpTensor<ops::Sub, DT_FLOAT>(this);
+ TestBinaryTensorOpTensor<ops::Mul, DT_FLOAT>(this);
+ TestBinaryTensorOpTensor<ops::Div, DT_FLOAT>(this);
+ TestBinaryTensorOpTensor<ops::RealDiv, DT_FLOAT>(this);
+ TestBinaryTensorOpTensor<ops::Minimum, DT_FLOAT>(this);
+ TestBinaryTensorOpTensor<ops::Maximum, DT_FLOAT>(this);
+
+ TestBinaryTensorOpTensor<ops::Add, DT_HALF>(this);
+ TestBinaryTensorOpTensor<ops::Sub, DT_HALF>(this);
+ TestBinaryTensorOpTensor<ops::Mul, DT_HALF>(this);
+ TestBinaryTensorOpTensor<ops::Div, DT_HALF>(this);
+ TestBinaryTensorOpTensor<ops::RealDiv, DT_HALF>(this);
+ TestBinaryTensorOpTensor<ops::Minimum, DT_HALF>(this);
+ TestBinaryTensorOpTensor<ops::Maximum, DT_HALF>(this);
+}
+
+TEST_F(OpConverterTest, ConvertQuantize) {
+ for (const string& op :
+ {"FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars",
+ "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"}) {
+ // Input list is empty, should fail.
+ NodeDef node_def = MakeNodeDef("my_quantize", op, {});
+ RunValidationAndConversion(
+ node_def, error::INVALID_ARGUMENT,
+ StrCat("Invalid number of inputs for ", op, ", at my_quantize")
+ .c_str());
+ }
+ {
+ // FakeQuantWithMinMaxArgs attributes are empty, should fail.
+ NodeDef node_def =
+ MakeNodeDef("my_quantize", "FakeQuantWithMinMaxArgs", {"input"});
+ AddTestTensor("input", {1, 2, 3});
+ RunValidationAndConversion(
+ node_def, error::INVALID_ARGUMENT,
+ "Min or max attribute not found for FakeQuantWithMinMaxArgs "
+ "at my_quantize");
+ }
+ {
+ // FakeQuantWithMinMaxArgs ranges set via attributes, ok.
+ Reset();
+ Scope s = Scope::NewRootScope();
+ auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
+ auto quantize_attrs = ops::FakeQuantWithMinMaxArgs::Min(-6.0f).Max(6.0f);
+ auto quantize = ops::FakeQuantWithMinMaxArgs(s.WithOpName("my_quantize"),
+ input, quantize_attrs);
+ const NodeDef& node_def = quantize.operation.node()->def();
+ AddTestTensor("input", {1, 2, 3});
+ RunValidationAndConversion(node_def);
+ TRT_TensorOrWeights output;
+ TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
+ EXPECT_TRUE(output.is_tensor());
+ auto ranges = quantization_ranges();
+ EXPECT_EQ(1, ranges.count(output.tensor()));
+ EXPECT_EQ(6.0f, ranges[output.tensor()]);
+ }
+ {
+ // FakeQuantWithMinMaxVars ranges set via inputs, ok.
+ Reset();
+ Scope s = Scope::NewRootScope();
+ auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
+ auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
+ auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
+ auto quantize = ops::FakeQuantWithMinMaxVars(
+ s.WithOpName("my_quantize"), input, weights_min, weights_max);
+ const NodeDef& node_def = quantize.operation.node()->def();
+ AddTestTensor("input", {1, 2, 3});
+ AddTestWeights<float>("weights_min", {1}, {-6.0f});
+ AddTestWeights<float>("weights_max", {1}, {6.0f});
+ RunValidationAndConversion(node_def);
+ TRT_TensorOrWeights output;
+ TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
+ EXPECT_TRUE(output.is_tensor());
+ auto ranges = quantization_ranges();
+ EXPECT_EQ(1, ranges.count(output.tensor()));
+ EXPECT_EQ(6.0f, ranges[output.tensor()]);
+ }
+ {
+ // QuantizeAndDequantizeV2 ranges set via inputs, ok.
+ Reset();
+ Scope s = Scope::NewRootScope();
+ auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
+ auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
+ auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
+ auto quantize = ops::QuantizeAndDequantizeV2(
+ s.WithOpName("my_quantize"), input, weights_min, weights_max);
+ const NodeDef& node_def = quantize.operation.node()->def();
+ AddTestTensor("input", {1, 2, 3});
+ AddTestWeights<float>("weights_min", {1}, {-6.0f});
+ AddTestWeights<float>("weights_max", {1}, {6.0f});
+ RunValidationAndConversion(node_def);
+ TRT_TensorOrWeights output;
+ TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
+ EXPECT_TRUE(output.is_tensor());
+ auto ranges = quantization_ranges();
+ EXPECT_EQ(1, ranges.count(output.tensor()));
+ EXPECT_EQ(6.0f, ranges[output.tensor()]);
+ }
+ {
+ // QuantizeAndDequantizeV2 Range inputs are tensors, should fail.
+ Reset();
+ Scope s = Scope::NewRootScope();
+ auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
+ auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
+ auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
+ auto quantize = ops::QuantizeAndDequantizeV2(
+ s.WithOpName("my_quantize"), input, weights_min, weights_max);
+ const NodeDef& node_def = quantize.operation.node()->def();
+ AddTestTensor("input", {1, 2, 3});
+ AddTestTensor("weights_min", {1});
+ AddTestTensor("weights_max", {1});
+ RunValidationAndConversion(
+ node_def, error::INVALID_ARGUMENT,
+ "Min and max inputs for QuantizeAndDequantizeV2 must be weights not "
+ "tensors, at my_quantize");
+ }
+ {
+ // QuantizeAndDequantizeV3 ranges set via inputs, ok.
+ Reset();
+ Scope s = Scope::NewRootScope();
+ auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
+ auto weights_min = ops::Placeholder(s.WithOpName("weights_min"), DT_FLOAT);
+ auto weights_max = ops::Placeholder(s.WithOpName("weights_max"), DT_FLOAT);
+ auto num_bits = ops::Placeholder(s.WithOpName("num_bits"), DT_INT32);
+ auto quantize = ops::QuantizeAndDequantizeV3(
+ s.WithOpName("my_quantize"), input, weights_min, weights_max, num_bits);
+ const NodeDef& node_def = quantize.operation.node()->def();
+ AddTestTensor("input", {1, 2, 3});
+ AddTestWeights<float>("weights_min", {1}, {-6.0f});
+ AddTestWeights<float>("weights_max", {1}, {6.0f});
+ AddTestWeights<int>("num_bits", {1}, {8});
+ RunValidationAndConversion(node_def);
+ TRT_TensorOrWeights output;
+ TF_EXPECT_OK(GetTensorOrWeights("my_quantize", &output));
+ EXPECT_TRUE(output.is_tensor());
+ auto ranges = quantization_ranges();
+ EXPECT_EQ(1, ranges.count(output.tensor()));
+ EXPECT_EQ(6.0f, ranges[output.tensor()]);
+ }
+}
+
+TEST_F(OpConverterTest, ConvertRelu6) {
+ {
+ // Input list is empty, should fail.
+ NodeDef node_def = MakeNodeDef("my_relu6", "Relu6", {});
+ RunValidationAndConversion(
+ node_def, error::INVALID_ARGUMENT,
+ "Invalid number of inputs for Relu6, at my_relu6");
+ }
+
+ // Get the NodeDef for Relu6.
+ Scope s = Scope::NewRootScope();
+ auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
+ auto relu6 = ops::Relu6(s.WithOpName("my_relu6"), input);
+ const NodeDef node_def = relu6.operation.node()->def();
+ {
+ // Input is weights, should fail.
+ Reset();
+ AddTestWeights<float>("input", {1}, {1.0f});
+ RunValidationAndConversion(
+ node_def, error::UNIMPLEMENTED,
+ "Relu6 is only implemented for tensors, not weights, at my_relu6");
+ }
+ {
+ // Clip tensor values and set quantization ranges, ok.
+ Reset();
+ AddTestTensor("input", {1, 2, 3});
+ RunValidationAndConversion(node_def);
+ TRT_TensorOrWeights output;
+ TF_EXPECT_OK(GetTensorOrWeights("my_relu6", &output));
+ EXPECT_TRUE(output.is_tensor());
+ auto ranges = quantization_ranges();
+ EXPECT_EQ(ranges[output.tensor()], 6.0f);
+
+ std::vector<float> output_data(6);
+ BuildAndRun<float>({{"input", {-100, -1, 0, 3, 5, 9}}}, "my_relu6",
+ &output_data);
+ EXPECT_THAT(output_data, ElementsAre(0, 0, 0, 3, 5, 6));
+ }
}
} // namespace convert
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
index b30d94b..4ac7e21 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
@@ -67,6 +67,9 @@
TF_RETURN_IF_ERROR(GetPrecisionMode(
Uppercase(params.at("precision_mode").s()), &precision_mode_));
}
+ if (params.count("use_calibration")) {
+ use_calibration_ = params.at("use_calibration").b();
+ }
return tensorflow::Status::OK();
}
@@ -222,6 +225,12 @@
TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
tensorflow::tensorrt::convert::ConversionParams cp;
+ if (use_calibration_ && precision_mode_ != INT8MODE) {
+ LOG(ERROR) << "Calibration with FP32 or FP16 is not implemented. "
+ << "Falling back to use_calibration = False.";
+ use_calibration_ = false;
+ }
+
std::vector<string> nodes_to_preserve;
for (const auto& n : item.NodesToPreserve()) {
auto tokens = str_util::Split(n, ":");
@@ -250,6 +259,7 @@
cp.is_dyn_op = is_dynamic_op_;
cp.cached_engine_batches = batches_;
cp.max_cached_engines = max_cached_batches_;
+ cp.use_calibration = use_calibration_;
auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp);
VLOG(1) << "Returning from " << name_;
return status;
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
index 71b51d1..3e8dc09 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.h
@@ -38,7 +38,8 @@
maximum_batch_size_(-1),
is_dynamic_op_(false),
max_cached_batches_(1),
- max_workspace_size_bytes_(256LL << 20) {
+ max_workspace_size_bytes_(256LL << 20),
+ use_calibration_(true) {
VLOG(1) << "Constructing " << name_;
}
@@ -67,6 +68,7 @@
std::vector<int> batches_;
int max_cached_batches_;
int64_t max_workspace_size_bytes_;
+ bool use_calibration_;
};
} // namespace convert
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 0194468..1e907e0 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -124,8 +124,10 @@
OP_REQUIRES_OK(context,
context->GetAttr("segment_funcdef_name", &funcdef_name_));
OP_REQUIRES_OK(context, GetPrecisionMode(precision_string, &precision_mode_));
- calibration_mode_ =
- (precision_mode_ == INT8MODE && calibration_data.size() == 0);
+ OP_REQUIRES_OK(context,
+ context->GetAttr("use_calibration", &use_calibration_));
+ calibration_mode_ = (use_calibration_ && precision_mode_ == INT8MODE &&
+ calibration_data.size() == 0);
if (calibration_data.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data));
calibration_data.resize(0);
@@ -308,7 +310,7 @@
std::vector<void*> buffers(num_binding);
for (int i = 0; i < ctx->num_inputs(); i++) {
const string input_name = StrCat(kInputPHName, i);
- const size_t binding_index =
+ const int binding_index =
trt_engine_ptr->getBindingIndex(input_name.c_str());
if (binding_index == -1) {
LOG(ERROR) << "Input node not found, at " << input_name;
@@ -345,7 +347,7 @@
for (int i = 0; i < ctx->num_outputs(); i++) {
// Create an output tensor
const string output_name = StrCat(kOutputPHName, i);
- const size_t binding_index =
+ const int binding_index =
trt_engine_ptr->getBindingIndex(output_name.c_str());
Tensor* output_tensor = nullptr;
@@ -497,7 +499,8 @@
// means calibration_mode_ is true and this path won't get executed.
auto status = convert::ConvertGraphDefToEngine(
segment_graph_, precision_mode_, batch_size, workspace_size_, shapes,
- &logger, allocator, calibrator_.get(), &engine, &convert_successfully);
+ &logger, allocator, calibrator_.get(), &engine, use_calibration_,
+ &convert_successfully);
if (!status.ok()) {
if (convert_successfully) {
// This means it fail to build the engine even when the network is built
@@ -586,6 +589,7 @@
*segment_graph, INT8MODE, cres->calibrator_->getBatchSize(),
workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(),
cres->calibrator_.get(), &cres->engine_,
+ /*use_calibration=*/true,
/*convert_successfully=*/nullptr);
if (!s.ok()) {
LOG(ERROR) << "Calibration failed: " << s;
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
index 8fe0675..b545f49 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -130,6 +130,10 @@
// The finalized calibrator for inference.
std::unique_ptr<TRTInt8Calibrator> calibrator_;
+
+ // If true, create calibration graph for INT8 mode. Otherwise, we are using
+ // user-provided quantization ranges.
+ bool use_calibration_;
};
} // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
index e0c7b62..9240590 100644
--- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
@@ -16,6 +16,7 @@
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -39,18 +40,19 @@
.Attr("cached_engine_batches: list(int) = []")
.Attr("max_cached_engines_count: int = 1")
.Attr("workspace_size_bytes: int")
- .Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}")
+ .Attr("precision_mode: {'FP32', 'FP16', 'INT8'}")
.Attr("calibration_data: string = ''")
+ .Attr("use_calibration: bool = true")
.Input("in_tensor: InT")
- .Output("out_tensor: OutT");
-// TODO(jie): TF requires concrete output shape for concrete input shapes.
-// This is tricky for batch dimension, since we cannot ensure which input
-// would carry the correct batch dimension (for the current stage of the
-// implementation, we do require all input tensor to carry the same batch
-// size, but this could change in the future). Hence we disable shape
-// inference function as a workaround.
-// .SetShapeFn(shape_inference::TRTEngineOpShapeInference);
-
+ .Output("out_tensor: OutT")
+ // TODO(jie): TF requires concrete output shape for concrete input shapes.
+ // This is tricky for batch dimension, since we cannot ensure which input
+ // would carry the correct batch dimension (for the current stage of the
+ // implementation, we do require all input tensor to carry the same batch
+ // size, but this could change in the future). Hence we disable shape
+ // inference function as a workaround.
+ // .SetShapeFn(shape_inference::TRTEngineOpShapeInference);
+ .SetShapeFn(shape_inference::UnknownShape);
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 0e59fdd..203b269 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -70,7 +70,8 @@
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batch_sizes=None):
+ cached_engine_batch_sizes=None,
+ use_calibration=True):
"""Returns a RewriterConfig proto for TRT transformation.
Args:
@@ -95,6 +96,15 @@
use this list to determine the batch sizes of the cached engines, instead
of making the decision on the fly. This is useful when we know the most
common batch size(s) the application is going to generate.
+ use_calibration: this argument is ignored if precision_mode is not INT8. If
+ set to True, a calibration graph will be created to calibrate the missing
+ ranges. The calibration graph must be converted to an inference graph
+ using calib_graph_to_infer_graph() after running calibration. if set to
+ False, quantization nodes will be expected for every tensor in the graph
+ (exlcuding those which will be fused). If a range is missing, an error
+ will occur. Please note that accuracy may be negatively affected if there
+ is a mismatch between which tensors TRT quantizes and which tensors were
+ trained with fake quantization.
Returns:
A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
@@ -141,6 +151,7 @@
"maximum_cached_engines items.")
optimizer.parameter_map["cached_engine_batches"].list.i.extend(
cached_engine_batch_sizes)
+ optimizer.parameter_map["use_calibration"].b = use_calibration
return rewriter_config_with_trt
@@ -153,6 +164,7 @@
is_dynamic_op=False,
maximum_cached_engines=1,
cached_engine_batch_sizes=None,
+ use_calibration=True,
input_saved_model_dir=None,
input_saved_model_tags=None,
output_saved_model_dir=None,
@@ -184,6 +196,15 @@
use this list to determine the batch sizes of the cached engines, instead
of making the decision on the fly. This is useful when we know the most
common batch size(s) the application is going to generate.
+ use_calibration: this argument is ignored if precision_mode is not INT8. If
+ set to True, a calibration graph will be created to calibrate the missing
+ ranges. The calibration graph must be converted to an inference graph
+ using calib_graph_to_infer_graph() after running calibration. if set to
+ False, quantization nodes will be expected for every tensor in the graph
+ (exlcuding those which will be fused). If a range is missing, an error
+ will occur. Please note that accuracy may be negatively affected if there
+ is a mismatch between which tensors TRT quantizes and which tensors were
+ trained with fake quantization.
input_saved_model_dir: the directory to load the SavedModel which contains
the input graph to transforms. Used only when input_graph_def is None.
input_saved_model_tags: list of tags to load the SavedModel.
@@ -333,7 +354,7 @@
rewriter_config_with_trt = get_tensorrt_rewriter_config(
rewriter_config, max_batch_size, max_workspace_size_bytes, precision_mode,
minimum_segment_size, is_dynamic_op, maximum_cached_engines,
- cached_engine_batch_sizes)
+ cached_engine_batch_sizes, use_calibration)
session_config_with_trt.graph_options.rewrite_options.CopyFrom(
rewriter_config_with_trt)
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
index 18096e0..b325d76 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -56,8 +56,9 @@
strides=[1, 2, 2, 1],
padding="SAME",
name="conv")
- bias = constant_op.constant(
- [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype)
+ bias = constant_op.constant([4., 1.5, 2., 3., 5., 7.],
+ name="bias",
+ dtype=dtype)
added = nn.bias_add(conv, bias, name="bias_add")
relu = nn.relu(added, "relu")
identity = array_ops.identity(relu, "identity")
@@ -73,11 +74,12 @@
def ExpectedEnginesToBuild(self, run_params):
"""Return the expected engines to build."""
- # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
- # breaks the connection check, fix it.
- # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add",
- # "relu", "identity", "max_pool"]
- return ["my_trt_op_0"]
+ return {
+ "my_trt_op_0": [
+ "weights", "conv", "bias", "bias_add", "relu", "identity",
+ "max_pool"
+ ]
+ }
class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase):
@@ -92,7 +94,7 @@
g = ops.Graph()
with g.as_default():
inp = array_ops.placeholder(
- dtype=dtype, shape=[None] + input_dims[1:], name=input_name)
+ dtype=dtype, shape=input_dims, name=input_name)
with g.device("/GPU:0"):
conv_filter = constant_op.constant(
[[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
@@ -105,10 +107,10 @@
padding="SAME",
name="conv")
c1 = constant_op.constant(
- np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c1")
+ np.random.randn(12, 12, 6), dtype=dtype, name="c1")
p = math_ops.mul(conv, c1, name="mul")
c2 = constant_op.constant(
- np.random.randn(input_dims[0], 12, 12, 6), dtype=dtype, name="c2")
+ np.random.randn(12, 12, 6), dtype=dtype, name="c2")
q = math_ops.div(conv, c2, name="div")
edge = self.trt_incompatible_op(q, name="incompatible")
@@ -129,22 +131,21 @@
def ExpectedEnginesToBuild(self, run_params):
"""Return the expected engines to build."""
- # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which
- # breaks the connection check, fix it.
- # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1",
- # "add", "sub1"];
- # - my_trt_op_1 should have ["weights","conv", "div"]
- return ["my_trt_op_0", "my_trt_op_1"]
+ return {
+ "my_trt_op_0": [
+ "add", "add1", "c1", "div1", "mul", "mul1", "sub", "sub1"
+ ],
+ "my_trt_op_1": ["c2", "conv", "div", "weights"]
+ }
- def ShouldRunTest(self, run_params):
- # TODO(aaroey): LayoutOptimizer adds Transpose(Const, Const) to the graph
- # which breaks the conversion. We should fix it as:
- # - Detect the invalid NodeDef earlier before adding them to segment
- # - Let it able to change the RewriterConfig when calling
- # create_inference_graph().
- # It will be good to add debugging feature for Grappler to print the graph
- # after running each optimizer.
- return False
+ def GetConversionParams(self, run_params):
+ """Return a ConversionParams for test."""
+ return super(
+ SimpleMultiEnginesTest, self
+ ).GetConversionParams(run_params)._replace(
+ # Disable layout optimizer, since it'll add Transpose(Const, Const) to
+ # the graph and breaks the conversion check.
+ rewriter_config=trt_test.OptimizerDisabledRewriterConfig())
class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
@@ -197,7 +198,9 @@
"""Whether to run the test."""
# Disable the test in fp16 mode since multiple matmul and add ops together
# can cause overflow.
- return run_params.precision_mode != "FP16"
+ return ((run_params.precision_mode != "FP16") and
+ not (trt_test.IsQuantizationMode(run_params.precision_mode) and
+ not run_params.use_calibration))
class PartiallyConvertedTestB(PartiallyConvertedTestA):
diff --git a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py
new file mode 100644
index 0000000..e7d6ec4
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py
@@ -0,0 +1,290 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to test TF-TRT INT8 conversion without calibration on Mnist model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.tensorrt.python import trt_convert
+# pylint: disable=unused-import
+from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+# pylint: enable=unused-import
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import data
+from tensorflow.python import keras
+from tensorflow.python.estimator.estimator import Estimator
+from tensorflow.python.estimator.model_fn import EstimatorSpec
+from tensorflow.python.estimator.model_fn import ModeKeys
+from tensorflow.python.estimator.run_config import RunConfig
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import graph_util
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.datasets import mnist
+from tensorflow.python.layers import layers
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.summary import summary
+from tensorflow.python.training import saver
+from tensorflow.python.training.adam import AdamOptimizer
+from tensorflow.python.training.checkpoint_management import latest_checkpoint
+from tensorflow.python.training.training_util import get_global_step
+
+INPUT_NODE_NAME = 'input'
+OUTPUT_NODE_NAME = 'output'
+
+
+class QuantizationAwareTrainingMNISTTest(test_util.TensorFlowTestCase):
+
+ def _BuildGraph(self, x):
+
+ def _Quantize(x, r):
+ x = gen_array_ops.quantize_and_dequantize_v2(x, -r, r)
+ return x
+
+ def _DenseLayer(x, num_inputs, num_outputs, quantization_range, name):
+ """Dense layer with quantized outputs.
+
+ Args:
+ x: input to the dense layer
+ num_inputs: number of input columns of x
+ num_outputs: number of output columns
+ quantization_range: the min/max range for quantization
+ name: name of the variable scope
+
+ Returns:
+ The output of the layer.
+ """
+ with variable_scope.variable_scope(name):
+ kernel = variable_scope.get_variable(
+ 'kernel',
+ shape=[num_inputs, num_outputs],
+ dtype=dtypes.float32,
+ initializer=keras.initializers.glorot_uniform())
+ bias = variable_scope.get_variable(
+ 'bias',
+ shape=[num_outputs],
+ dtype=dtypes.float32,
+ initializer=keras.initializers.zeros())
+ x = math_ops.matmul(x, kernel)
+ x = _Quantize(x, quantization_range)
+ x = nn.bias_add(x, bias)
+ x = _Quantize(x, quantization_range)
+ return x
+
+ x = _Quantize(x, 1)
+ # Conv + Bias + Relu6
+ x = layers.conv2d(x, filters=32, kernel_size=3, use_bias=True)
+ x = nn.relu6(x)
+ # Conv + Bias + Relu6
+ x = layers.conv2d(x, filters=64, kernel_size=3, use_bias=True)
+ x = nn.relu6(x)
+ # Reduce
+ x = math_ops.reduce_mean(x, [1, 2])
+ x = _Quantize(x, 6)
+ # FC1
+ x = _DenseLayer(x, 64, 512, 6, name='dense')
+ x = nn.relu6(x)
+ # FC2
+ x = _DenseLayer(x, 512, 10, 25, name='dense_1')
+ x = array_ops.identity(x, name=OUTPUT_NODE_NAME)
+ return x
+
+ def _GetGraphDef(self, use_trt, max_batch_size, model_dir):
+ """Get the frozen mnist GraphDef.
+
+ Args:
+ use_trt: whether use TF-TRT to convert the graph.
+ max_batch_size: the max batch size to apply during TF-TRT conversion.
+ model_dir: the model directory to load the checkpoints.
+
+ Returns:
+ The frozen mnist GraphDef.
+ """
+ graph = ops.Graph()
+ with self.session(graph=graph) as sess:
+ with graph.device('/GPU:0'):
+ x = array_ops.placeholder(
+ shape=(None, 28, 28, 1), dtype=dtypes.float32, name=INPUT_NODE_NAME)
+ self._BuildGraph(x)
+ # Load weights
+ mnist_saver = saver.Saver()
+ checkpoint_file = latest_checkpoint(model_dir)
+ mnist_saver.restore(sess, checkpoint_file)
+ # Freeze
+ graph_def = graph_util.convert_variables_to_constants(
+ sess, sess.graph_def, output_node_names=[OUTPUT_NODE_NAME])
+ # Convert with TF-TRT
+ if use_trt:
+ logging.info('Number of nodes before TF-TRT conversion: %d',
+ len(graph_def.node))
+ graph_def = trt_convert.create_inference_graph(
+ graph_def,
+ outputs=[OUTPUT_NODE_NAME],
+ max_batch_size=max_batch_size,
+ precision_mode='INT8',
+ max_workspace_size_bytes=4096 << 19,
+ minimum_segment_size=2,
+ use_calibration=False,
+ )
+ logging.info('Number of nodes after TF-TRT conversion: %d',
+ len(graph_def.node))
+ num_engines = len(
+ [1 for n in graph_def.node if str(n.op) == 'TRTEngineOp'])
+ self.assertEqual(1, num_engines)
+ return graph_def
+
+ def _Run(self, is_training, use_trt, batch_size, num_epochs, model_dir):
+ """Train or evaluate the model.
+
+ Args:
+ is_training: whether to train or evaluate the model. In training mode,
+ quantization will be simulated where the quantize_and_dequantize_v2 are
+ placed.
+ use_trt: if true, use TRT INT8 mode for evaluation, which will perform
+ real quantization. Otherwise use native TensorFlow which will perform
+ simulated quantization. Ignored if is_training is True.
+ batch_size: batch size.
+ num_epochs: how many epochs to train. Ignored if is_training is False.
+ model_dir: where to save or load checkpoint.
+
+ Returns:
+ The Estimator evaluation result.
+ """
+ # Get dataset
+ train_data, test_data = mnist.load_data()
+
+ def _PreprocessFn(x, y):
+ x = math_ops.cast(x, dtypes.float32)
+ x = array_ops.expand_dims(x, axis=2)
+ x = 2.0 * (x / 255.0) - 1.0
+ y = math_ops.cast(y, dtypes.int32)
+ return x, y
+
+ def _EvalInputFn():
+ mnist_x, mnist_y = test_data
+ dataset = data.Dataset.from_tensor_slices((mnist_x, mnist_y))
+ dataset = dataset.apply(
+ data.experimental.map_and_batch(
+ map_func=_PreprocessFn,
+ batch_size=batch_size,
+ num_parallel_calls=8))
+ dataset = dataset.repeat(count=1)
+ iterator = dataset.make_one_shot_iterator()
+ features, labels = iterator.get_next()
+ return features, labels
+
+ def _TrainInputFn():
+ mnist_x, mnist_y = train_data
+ dataset = data.Dataset.from_tensor_slices((mnist_x, mnist_y))
+ dataset = dataset.shuffle(2 * len(mnist_x))
+ dataset = dataset.apply(
+ data.experimental.map_and_batch(
+ map_func=_PreprocessFn,
+ batch_size=batch_size,
+ num_parallel_calls=8))
+ dataset = dataset.repeat(count=num_epochs)
+ iterator = dataset.make_one_shot_iterator()
+ features, labels = iterator.get_next()
+ return features, labels
+
+ def _ModelFn(features, labels, mode):
+ if is_training:
+ logits_out = self._BuildGraph(features)
+ else:
+ graph_def = self._GetGraphDef(use_trt, batch_size, model_dir)
+ logits_out = importer.import_graph_def(
+ graph_def,
+ input_map={INPUT_NODE_NAME: features},
+ return_elements=[OUTPUT_NODE_NAME + ':0'],
+ name='')[0]
+
+ loss = losses.sparse_softmax_cross_entropy(
+ labels=labels, logits=logits_out)
+ summary.scalar('loss', loss)
+
+ classes_out = math_ops.argmax(logits_out, axis=1, name='classes_out')
+ accuracy = metrics.accuracy(
+ labels=labels, predictions=classes_out, name='acc_op')
+ summary.scalar('accuracy', accuracy[1])
+
+ if mode == ModeKeys.EVAL:
+ return EstimatorSpec(
+ mode, loss=loss, eval_metric_ops={'accuracy': accuracy})
+ elif mode == ModeKeys.TRAIN:
+ optimizer = AdamOptimizer(learning_rate=1e-2)
+ train_op = optimizer.minimize(loss, global_step=get_global_step())
+ return EstimatorSpec(mode, loss=loss, train_op=train_op)
+
+ config_proto = config_pb2.ConfigProto()
+ config_proto.gpu_options.allow_growth = True
+ estimator = Estimator(
+ model_fn=_ModelFn,
+ model_dir=model_dir if is_training else None,
+ config=RunConfig(session_config=config_proto))
+
+ if is_training:
+ estimator.train(_TrainInputFn)
+ results = estimator.evaluate(_EvalInputFn)
+ logging.info('accuracy: %s', str(results['accuracy']))
+ return results
+
+ # To generate the checkpoint, set a different model_dir and call self._Run()
+ # by setting is_training=True and num_epochs=1000, e.g.:
+ # model_dir = '/tmp/quantization_mnist'
+ # self._Run(
+ # is_training=True,
+ # use_trt=False,
+ # batch_size=128,
+ # num_epochs=100,
+ # model_dir=model_dir)
+ def testEval(self):
+ if not trt_convert.is_tensorrt_enabled():
+ return
+ model_dir = test.test_src_dir_path('contrib/tensorrt/test/testdata')
+
+ accuracy_tf_native = self._Run(
+ is_training=False,
+ use_trt=False,
+ batch_size=128,
+ num_epochs=None,
+ model_dir=model_dir)['accuracy']
+ logging.info('accuracy_tf_native: %f', accuracy_tf_native)
+ self.assertAllClose(accuracy_tf_native, 0.9662)
+
+ if trt_convert.get_linked_tensorrt_version()[0] < 5:
+ return
+
+ accuracy_tf_trt = self._Run(
+ is_training=False,
+ use_trt=True,
+ batch_size=128,
+ num_epochs=None,
+ model_dir=model_dir)['accuracy']
+ logging.info('accuracy_tf_trt: %f', accuracy_tf_trt)
+ self.assertAllClose(accuracy_tf_trt, 0.9677)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/tensorrt/test/quantization_test.py b/tensorflow/contrib/tensorrt/test/quantization_test.py
new file mode 100644
index 0000000..2835327
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/quantization_test.py
@@ -0,0 +1,144 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model script to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.tensorrt.python import trt_convert
+from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+def _GetParams(add_quantization_nodes, dtype=dtypes.float32):
+ input_name = "input"
+ input_dims = [8, 8]
+ output_name = "output"
+
+ def _Quantize(x, r):
+ if add_quantization_nodes:
+ x = gen_array_ops.fake_quant_with_min_max_vars(x, -r, r)
+ return x
+
+ g = ops.Graph()
+ with g.as_default():
+ x = array_ops.placeholder(
+ dtype=dtype, shape=[None] + input_dims[1:], name=input_name)
+ x = _Quantize(x, 10.0)
+ x = x + 5
+ x = _Quantize(x, 15.0)
+ x = x - 5
+ x = _Quantize(x, 10.0)
+ x = x * 0.1
+ x = _Quantize(x, 1.0)
+ w = constant_op.constant(np.ones((8, 1)), dtype=dtypes.float32)
+ x = math_ops.matmul(x, w)
+ x = _Quantize(x, 10.0)
+ x = array_ops.identity(x, name=output_name)
+
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ output_names=[output_name],
+ expected_output_dims=[(8, 1)])
+
+
+class QuantizationMissingAllRangesTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing single segment with no quantization ranges."""
+ return _GetParams(add_quantization_nodes=False)
+
+ def ShouldRunTest(self, run_params):
+ if trt_convert.get_linked_tensorrt_version()[0] < 5:
+ return False
+ # Only test static engine mode, with or without calibration.
+ return (trt_test.IsQuantizationMode(run_params.precision_mode) and
+ not run_params.use_optimizer and not run_params.dynamic_engine)
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ if run_params.use_calibration:
+ # In static engine mode with calibration, it should build a calibration
+ # engine.
+ return ["my_trt_op_0"]
+ # In static engine mode without calibration, the engine building will fail
+ # since no quantization ranges are set, which results in no TRT nodes.
+ return []
+
+
+class QuantizationWithRangesTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing single segment with no quantization ranges."""
+ return _GetParams(add_quantization_nodes=True)
+
+ def ShouldRunTest(self, run_params):
+ if trt_convert.get_linked_tensorrt_version()[0] < 5:
+ return False
+ # Test static/dynamic engine with/without calibration.
+ return (trt_test.IsQuantizationMode(run_params.precision_mode) and
+ not run_params.use_optimizer)
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01
+
+
+class NonQuantizedPrecisionsWithRangesTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing single segment with no quantization ranges."""
+ return _GetParams(add_quantization_nodes=True)
+
+ def ShouldRunTest(self, run_params):
+ # Only test FP32/FP16 mode.
+ return not trt_test.IsQuantizationMode(run_params.precision_mode)
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ # The fake quant ops are not supported in FP32/FP16 mode, and will split the
+ # graph into three TRT segments.
+ return ["my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3"]
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-01
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/tensorrt/test/testdata/checkpoint b/tensorflow/contrib/tensorrt/test/testdata/checkpoint
new file mode 100644
index 0000000..a603e1a
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/testdata/checkpoint
@@ -0,0 +1,3 @@
+model_checkpoint_path: "model.ckpt-46900"
+all_model_checkpoint_paths: "model.ckpt-0"
+all_model_checkpoint_paths: "model.ckpt-46900"
diff --git a/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001 b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001
new file mode 100644
index 0000000..88a998f
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001
Binary files differ
diff --git a/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index
new file mode 100644
index 0000000..5379765
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index
Binary files differ
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index c3cff28..80eb855 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -43,14 +43,15 @@
"gdef", "input_names", "input_dims", "output_names", "expected_output_dims"
])
-RunParams = namedtuple(
- "RunParams",
- ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"])
+RunParams = namedtuple("RunParams", [
+ "use_optimizer", "precision_mode", "dynamic_engine", "test_name",
+ "use_calibration"
+])
ConversionParams = namedtuple("ConversionParams", [
"max_batch_size", "max_workspace_size_bytes", "precision_mode",
"minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
- "cached_engine_batch_sizes", "rewriter_config"
+ "cached_engine_batch_sizes", "rewriter_config", "use_calibration"
])
PRECISION_MODES = ["FP32", "FP16", "INT8"]
@@ -69,6 +70,8 @@
def OptimizerDisabledRewriterConfig():
"""Returns a RewriterConfig with all default Grappler optimizers disabled."""
rewriter_config = rewriter_config_pb2.RewriterConfig()
+
+ # Turn off all default Grappler optimizers.
off = rewriter_config_pb2.RewriterConfig.OFF
rewriter_config.layout_optimizer = off
rewriter_config.constant_folding = off
@@ -85,6 +88,10 @@
rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
rewriter_config.pin_to_host_optimization = off
rewriter_config.auto_parallel.enable = False
+
+ # Run only once for each enabled optimizer.
+ rewriter_config.meta_optimizer_iterations = (
+ rewriter_config_pb2.RewriterConfig.ONE)
return rewriter_config
@@ -162,11 +169,15 @@
is_dynamic_op=run_params.dynamic_engine,
maximum_cached_engines=1,
cached_engine_batch_sizes=None,
- rewriter_config=None)
+ rewriter_config=None,
+ use_calibration=run_params.use_calibration)
def ShouldRunTest(self, run_params):
"""Whether to run the test."""
- return True
+ # This setting combination requires quantization nodes to be present in
+ # order to build the engine.
+ return not (IsQuantizationMode(run_params.precision_mode) and
+ not run_params.use_calibration)
def VerifyRunForEngine(self, engine_name, graph_state, expect_run=True):
"""Verify the state of a particular engine after sess.run()."""
@@ -237,7 +248,8 @@
conversion_params.minimum_segment_size,
conversion_params.is_dynamic_op,
conversion_params.maximum_cached_engines,
- conversion_params.cached_engine_batch_sizes)
+ conversion_params.cached_engine_batch_sizes,
+ conversion_params.use_calibration)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
else:
@@ -329,6 +341,7 @@
is_dynamic_op=conversion_params.is_dynamic_op,
maximum_cached_engines=conversion_params.maximum_cached_engines,
cached_engine_batch_sizes=conversion_params.cached_engine_batch_sizes,
+ use_calibration=conversion_params.use_calibration,
session_config=config_for_trt)
def _WriteGraph(self, run_params, gdef, graph_state):
@@ -428,10 +441,12 @@
is_dynamic_engine = not node.attr["static_engine"].b
self.assertEqual(run_params.dynamic_engine, is_dynamic_engine,
node.name)
+ self.assertEqual(node.attr["use_calibration"].b,
+ run_params.use_calibration, node.name)
has_calibration_data = len(node.attr["calibration_data"].s)
if (IsQuantizationMode(run_params.precision_mode) and
- graph_state == GraphState.INFERENCE):
+ run_params.use_calibration and graph_state == GraphState.INFERENCE):
self.assertTrue(has_calibration_data, node.name)
else:
self.assertFalse(has_calibration_data, node.name)
@@ -482,7 +497,8 @@
config_no_trt, GraphState.ORIGINAL)
# Run calibration if necessary.
- if IsQuantizationMode(run_params.precision_mode):
+ if (IsQuantizationMode(run_params.precision_mode) and
+ run_params.use_calibration):
calib_config = self._GetConfigProto(run_params, GraphState.CALIBRATE)
logging.info("Running calibration graph, config:\n%s", str(calib_config))
@@ -552,27 +568,38 @@
use_optimizer_options = [False, True]
dynamic_engine_options = [False, True]
- for (use_optimizer, precision_mode, dynamic_engine) in itertools.product(
- use_optimizer_options, PRECISION_MODES, dynamic_engine_options):
+ use_calibration_options = [False, True]
+ opts = itertools.product(use_optimizer_options, PRECISION_MODES,
+ dynamic_engine_options, use_calibration_options)
+ for (use_optimizer, precision_mode, dynamic_engine, use_calibration) in opts:
if IsQuantizationMode(precision_mode):
if use_optimizer:
# TODO(aaroey): if use_optimizer is True we need to get the inference
# graphdef using custom python wrapper class, which is not currently
# supported yet.
continue
- if not dynamic_engine:
+ if use_calibration and not dynamic_engine:
+ # Static engine with use_calibration=False will be static, so we want to
+ # test that. If use_calibration=True, only dynamic op is supported.
# TODO(aaroey): construction of static calibration engine is not
# supported yet.
continue
+ else:
+ if use_calibration:
+ # Don't calibrate in FP32 or FP16 mode
+ continue
conversion = "OptimizerConversion" if use_optimizer else "ToolConversion"
- engine_type = ("DynamicEngine" if dynamic_engine else "StaticEngine")
- test_name = "%s_%s_%s" % (conversion, precision_mode, engine_type)
+ engine_type = "DynamicEngine" if dynamic_engine else "StaticEngine"
+ calibration_type = "UseCalibration" if use_calibration else "NoCalibration"
+ test_name = "%s_%s_%s_%s" % (conversion, engine_type, precision_mode,
+ calibration_type)
run_params = RunParams(
use_optimizer=use_optimizer,
precision_mode=precision_mode,
dynamic_engine=dynamic_engine,
- test_name=test_name)
+ test_name=test_name,
+ use_calibration=use_calibration)
setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params))
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index a0a9cb3..9992740 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -14,6 +14,7 @@
package(
default_visibility = [
"//cloud/vmm/testing/tests/tpu:__subpackages__",
+ "//knowledge/cerebra/sense/im2query:__subpackages__",
"//learning/brain:__subpackages__",
"//learning/deepmind:__subpackages__",
"//medical/pathology:__subpackages__",
diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py
index c694e9c..d61c824 100644
--- a/tensorflow/contrib/tpu/python/tpu/datasets.py
+++ b/tensorflow/contrib/tpu/python/tpu/datasets.py
@@ -133,7 +133,7 @@
with ops.device('/job:%s' % file_reader_job):
if isinstance(files, str):
source_dataset = dataset_ops.Dataset.list_files(files)
- elif isinstance(files, dataset_ops.Dataset):
+ elif isinstance(files, dataset_ops.DatasetV2):
source_dataset = files
else:
raise ValueError('files was not a string or a dataset: %s' % files)
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 73753cd..cf3b2e6 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -81,6 +81,7 @@
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers as keras_optimizers
from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import training_arrays
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.layers import embeddings
@@ -438,7 +439,7 @@
self._default_placeholder = array_ops.placeholder
self._default_name_scope = ops.name_scope
- self._default_make_variable = base_layer.make_variable
+ self._default_make_variable = base_layer_utils.make_variable
self._default_random_normal = random_ops.random_normal
self._default_qr = gen_linalg_ops.qr
@@ -486,14 +487,14 @@
gen_linalg_ops.qr = qr
ops.name_scope = _name_scope
- base_layer.make_variable = variable_scope.get_variable
+ base_layer_utils.make_variable = variable_scope.get_variable
logging.info('Overriding default placeholder.')
return
def __exit__(self, exc_type, exc_val, exc_tb):
array_ops.placeholder = self._default_placeholder
ops.name_scope = self._default_name_scope
- base_layer.make_variable = self._default_make_variable
+ base_layer_utils.make_variable = self._default_make_variable
random_ops.random_normal = self._default_random_normal
gen_linalg_ops.qr = self._default_qr
@@ -769,7 +770,7 @@
def _verify_dataset_shape(self, dataset):
"""Verifies a dataset is of an appropriate shape for TPUs."""
- if not isinstance(dataset, dataset_ops.Dataset):
+ if not isinstance(dataset, dataset_ops.DatasetV2):
raise ValueError('The function passed as the `x` parameter did not '
'return a `tf.data.Dataset`.')
if not isinstance(dataset.output_classes, tuple):
@@ -1465,7 +1466,7 @@
assert not self._numpy_to_infeed_manager_list # Ensure empty.
infeed_managers = [] # Managers to clean up at the end of the fit call.
- if isinstance(x, dataset_ops.Dataset):
+ if isinstance(x, dataset_ops.DatasetV2):
# TODO(b/111413240): Support taking a tf.data.Dataset directly.
raise ValueError(
'Taking a Dataset directly is not yet supported. Please '
@@ -1491,7 +1492,7 @@
y = infeed_manager.dummy_y
infeed_managers.append((x, infeed_manager))
- if isinstance(validation_data, dataset_ops.Dataset):
+ if isinstance(validation_data, dataset_ops.DatasetV2):
# TODO(b/111413240): Support taking a tf.data.Dataset directly.
raise ValueError(
'Taking a Dataset directly is not yet supported. Please '
@@ -1550,7 +1551,7 @@
with _tpu_session_context():
# Managers to clean up at the end of the evaluate call.
infeed_managers = []
- if isinstance(x, dataset_ops.Dataset):
+ if isinstance(x, dataset_ops.DatasetV2):
# TODO(b/111413240): Support taking a tf.data.Dataset directly.
raise ValueError(
'Taking a Dataset directly is not yet supported. Please '
@@ -1922,7 +1923,7 @@
if validation_data:
if (isinstance(validation_data, iterator_ops.Iterator) or
isinstance(validation_data, iterator_ops.EagerIterator) or
- isinstance(validation_data, dataset_ops.Dataset)):
+ isinstance(validation_data, dataset_ops.DatasetV2)):
raise ValueError('KerasTPUModel cannot handle a Dataset or Iterator '
'for validation_data. Please instead pass a function '
'that returns a `tf.data.Dataset`.')
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index a023612..def57da 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -1111,7 +1111,7 @@
Raises:
RuntimeError: if validation failed.
"""
- if not any([x.type == "GuaranteeConst" for x in graph.get_operations()]):
+ if not any(x.type == "GuaranteeConst" for x in graph.get_operations()):
raise RuntimeError(
"No GuaranteeConst ops found in the graph after running "
"tpu.rewrite_for_inference(...). Please check that you are using "
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py
index 3fe8964..ccba8a4 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py
@@ -1069,17 +1069,14 @@
'As TPU embedding is not optimized for small tables, '
'please consider other ways for this embedding lookup.')
- slicing = [num_hosts, 1]
-
- # TODO(shizhiw): deprecated, use tf.get_variable()?
- return partitioned_variables.create_partitioned_variables(
- name=name,
- slicing=slicing,
+ return list(variable_scope.get_variable(
+ name,
shape=(vocabulary_size, embedding_dimension),
+ partitioner=partitioned_variables.fixed_size_partitioner(num_hosts),
dtype=dtypes.float32,
initializer=initializer,
collections=collections,
- trainable=False)
+ trainable=False))
@ops.RegisterGradient('TPUEmbeddingActivations')
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 932367f..7171587 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -2169,7 +2169,6 @@
builder,
input_receiver_fn_map,
checkpoint_path,
- strip_default_attrs,
save_variables=True,
mode=model_fn_lib.ModeKeys.PREDICT,
export_tags=None,
@@ -2184,7 +2183,6 @@
builder,
input_receiver_fn_map,
checkpoint_path,
- strip_default_attrs,
save_variables,
mode=mode,
export_tags=export_tags,
@@ -2201,7 +2199,6 @@
builder,
input_receiver_fn_map,
checkpoint_path,
- strip_default_attrs,
save_variables=False,
mode=mode,
export_tags=export_tags,
@@ -2783,7 +2780,7 @@
elif isinstance(export_output, export_output_lib.RegressionOutput):
return [export_output.value]
elif isinstance(export_output, export_output_lib.PredictOutput):
- return export_output.outputs.values()
+ return list(export_output.outputs.values())
else:
raise ValueError(
'`export_output` must be have type `ClassificationOutput`, '
@@ -3059,7 +3056,7 @@
@staticmethod
def from_input_fn(return_values):
"""Returns an `_Inputs` instance according to `input_fn` return value."""
- if isinstance(return_values, dataset_ops.Dataset):
+ if isinstance(return_values, dataset_ops.DatasetV2):
dataset = return_values
return _Inputs(dataset=dataset)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
index cf36103..d5957b7 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py
@@ -833,24 +833,30 @@
dims = np.array(dims)
self._check_input_partition_dims(tensor, dims)
output = [tensor]
- divds, remainders = np.divmod(np.array(tensor.shape.as_list()), dims)
- for axis, (divd, remainder, dim) in enumerate(
- np.dstack((divds, remainders, dims))[0]):
+ shape_list = np.array(tensor.shape.as_list())
+ quotients, remainders = np.divmod(shape_list, dims)
+ for axis, (quotient, remainder, dim, original_size) in enumerate(
+ zip(quotients, remainders, dims, shape_list)):
if dim <= 1:
continue
if remainder > 0:
# For each dimension, when it cannot be evenly partitioned, XLA assumes
- # the size of last parts are smaller by 1. E.g. 2D tensor with shape
- # (5, 14) and dims are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14]
- # => [[(3, 3), (3, 3), (2, 3), (2, 3)],
- # [(2, 3), (2, 3), (2, 2), (2, 2)]]
- output = [
- array_ops.split(
- x,
- num_or_size_splits=[divd + 1] * remainder +
- [divd] * (dim - remainder),
- axis=axis) for x in output
- ]
+ # tensors are partitioned in a greedy manner by using
+ # ceil_ratio(size/dim) first. E.g. 2D tensor with shape (5, 14) and dims
+ # are (2, 4). Since 5 % 2 = 1 and 14 % 4 = 2, [5, 14] =>
+ # [[(3, 4), (3, 4), (2, 4), (2, 2)],
+ # [(2, 4), (2, 4), (2, 4), (2, 2)]]
+ ceil_ratio = quotient + 1
+ num_full_slots, left_over = np.divmod(original_size, ceil_ratio)
+ num_or_size_splits = [ceil_ratio] * num_full_slots + [left_over]
+ if len(num_or_size_splits) < dim:
+ num_or_size_splits += [0] * (dim - len(num_or_size_splits))
+ new_output = []
+ for x in output:
+ new_output.append(
+ array_ops.split(
+ x, num_or_size_splits=num_or_size_splits, axis=axis))
+ output = new_output
else:
output = [array_ops.split(x, dim, axis=axis) for x in output]
output = nest.flatten(output)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 2a8c271..9cd0eda 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -95,7 +95,7 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_tests_gpu")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule")
-load("//tensorflow:tensorflow.bzl", "if_not_tx2_llvm_or_windows_cuda")
+load("//tensorflow:tensorflow.bzl", "if_nccl")
load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test")
# For platform specific build config
@@ -112,6 +112,7 @@
"tf_additional_device_tracer_test_flags",
"tf_additional_gdr_lib_defines",
"tf_additional_human_readable_json_deps",
+ "tf_additional_logger_deps",
"tf_additional_lib_defines",
"tf_additional_lib_deps",
"tf_additional_lib_hdrs",
@@ -443,6 +444,18 @@
] + tf_additional_human_readable_json_deps(),
)
+cc_library(
+ name = "logger",
+ srcs = tf_platform_srcs(["logger.cc"]),
+ hdrs = ["platform/logger.h"] + tf_platform_hdrs(["logger.h"]),
+ copts = tf_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":lib",
+ ":lib_internal",
+ ] + tf_additional_logger_deps(),
+)
+
filegroup(
name = "platform_env_hdrs",
srcs = [
@@ -915,6 +928,7 @@
"util/stream_executor_util.h",
"util/strided_slice_op.h",
"util/tensor_format.h",
+ "util/tensor_ops_util.h",
"util/tensor_slice_reader.h",
"util/tensor_slice_reader_cache.h",
"util/tensor_slice_writer.h",
@@ -1403,9 +1417,7 @@
"//tensorflow/core/kernels:summary_kernels",
"//tensorflow/core/kernels:training_ops",
"//tensorflow/core/kernels:word2vec_kernels",
- ] + tf_additional_cloud_kernel_deps() + if_not_tx2_llvm_or_windows_cuda([
- "//tensorflow/core/kernels:nccl_kernels",
- ]) + if_not_windows([
+ ] + tf_additional_cloud_kernel_deps() + if_not_windows([
"//tensorflow/core/kernels:fact_op",
"//tensorflow/core/kernels:array_not_windows",
"//tensorflow/core/kernels:math_not_windows",
@@ -1430,6 +1442,8 @@
]) + if_cuda([
"//tensorflow/core/grappler/optimizers:gpu_swapping_kernels",
"//tensorflow/core/grappler/optimizers:gpu_swapping_ops",
+ ]) + if_nccl([
+ "//tensorflow/core/kernels:nccl_kernels",
]),
)
@@ -1594,6 +1608,8 @@
"util/stats_calculator.*",
"util/reporter.*",
"platform/**/cuda_libdevice_path.*",
+ "platform/**/logger.cc",
+ "platform/**/logger.h",
"platform/default/test_benchmark.*",
"platform/cuda.h",
"platform/google/**/*",
@@ -2206,6 +2222,7 @@
"platform/**/env_time.cc",
"platform/**/cuda_libdevice_path.cc",
"platform/**/device_tracer.cc",
+ "platform/**/logger.cc",
"platform/**/logging.cc",
"platform/**/human_readable_json.cc",
"platform/abi.cc",
@@ -2218,6 +2235,7 @@
"platform/**/stream_executor.h",
"platform/**/env_time.cc",
"platform/**/device_tracer.cc",
+ "platform/**/logger.cc",
"platform/**/logging.cc",
"platform/**/human_readable_json.cc",
"platform/abi.cc",
@@ -2963,6 +2981,7 @@
":lib_internal",
":proto_text",
":protos_all_cc",
+ "@com_google_absl//absl/memory",
"//third_party/eigen3",
"//tensorflow/core/grappler:grappler_item",
] + mkl_deps(),
@@ -3023,6 +3042,15 @@
)
tf_cuda_library(
+ name = "metrics",
+ srcs = ["common_runtime/metrics.cc"],
+ hdrs = ["common_runtime/metrics.h"],
+ deps = [
+ ":lib",
+ ],
+)
+
+tf_cuda_library(
name = "direct_session_internal",
srcs = ["common_runtime/direct_session.cc"],
hdrs = [
@@ -3038,6 +3066,7 @@
":graph",
":lib",
":lib_internal",
+ ":metrics",
":proto_text",
":protos_all_cc",
"//tensorflow/core/debug:debug_graph_utils",
@@ -3401,6 +3430,7 @@
"platform/profile_utils/cpu_utils_test.cc",
"platform/stacktrace_handler_test.cc",
"platform/subprocess_test.cc",
+ "platform/vmodule_benchmark_test.cc",
],
deps = [
":lib",
@@ -3415,6 +3445,20 @@
)
tf_cc_test(
+ name = "vmodule_test",
+ srcs = ["platform/vmodule_test.cc"],
+ tags = ["optonly"],
+ deps = [
+ ":lib",
+ ":lib_internal",
+ ":lib_test_internal",
+ ":protos_all_cc",
+ ":test",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_cc_test(
name = "lib_random_random_distributions_test",
srcs = ["lib/random/random_distributions_test.cc"],
tags = ["optonly"],
@@ -3816,6 +3860,7 @@
":test",
":test_main",
":testlib",
+ "@com_google_absl//absl/memory",
],
)
@@ -3844,6 +3889,7 @@
":test",
":test_main",
":testlib",
+ "@com_google_absl//absl/memory",
],
)
@@ -4411,6 +4457,7 @@
"//tensorflow/core/kernels:random_ops",
"//tensorflow/core/kernels:shape_ops",
"//third_party/eigen3",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalMatchingFilesDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalMatchingFilesDataset.pbtxt
new file mode 100644
index 0000000..993a798
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalMatchingFilesDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalMatchingFilesDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalMaxIntraOpParallelismDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaxIntraOpParallelismDataset.pbtxt
new file mode 100644
index 0000000..a18aa37
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaxIntraOpParallelismDataset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalMaxIntraOpParallelismDataset"
+ in_arg {
+ name: "max_intra_op_parallelism"
+ description: <<END
+Identifies the maximum intra-op parallelism to use.
+END
+ }
+ summary: <<END
+Creates a dataset that overrides the maximum intra-op parallelism.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalPrivateThreadPoolDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalPrivateThreadPoolDataset.pbtxt
new file mode 100644
index 0000000..eaa49b7
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalPrivateThreadPoolDataset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalPrivateThreadPoolDataset"
+ in_arg {
+ name: "num_threads"
+ description: <<END
+Identifies the number of threads to use for the private threadpool.
+END
+ }
+ summary: <<END
+Creates a dataset that uses a custom thread pool to compute `input_dataset`.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MatchingFilesDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatchingFilesDataset.pbtxt
deleted file mode 100644
index ab2a331..0000000
--- a/tensorflow/core/api_def/base_api/api_def_MatchingFilesDataset.pbtxt
+++ /dev/null
@@ -1,4 +0,0 @@
-op {
- graph_op_name: "MatchingFilesDataset"
- visibility: HIDDEN
-}
diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt
index c431425..dff7c87 100644
--- a/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_QuantizeAndDequantizeV2.pbtxt
@@ -41,6 +41,19 @@
Whether the range is given or should be determined from the `input` tensor.
END
}
+ attr {
+ name: "round_mode"
+ description: <<END
+The 'round_mode' attribute controls which rounding tie-breaking algorithm is
+used when rounding float values to their quantized equivalents. The following
+rounding modes are currently supported:
+
+* HALF_TO_EVEN: this is the default round_mode.
+* HALF_UP: round towards positive. In this mode 7.5 rounds up to 8 and -7.5
+ rounds up to -7.
+
+END
+ }
summary: "Quantizes then dequantizes a tensor."
description: <<END
This op simulates the precision loss from the quantized forward pass by:
@@ -93,7 +106,7 @@
output = round(clamp(value, input_min, input_max) * scale_factor) / scale_factor.
-The above round function uses half to even rounding.
+The above round function rounds the value based on the given round_mode.
END
}
diff --git a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
index 801dfbc..94ffc7c 100644
--- a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
@@ -1,7 +1,9 @@
op {
graph_op_name: "BatchToSpaceND"
+ deprecation_message: "use batch_to_space"
endpoint {
name: "batch_to_space_nd"
+ deprecation_version: 2
}
endpoint {
name: "manip.batch_to_space_nd"
diff --git a/tensorflow/core/api_def/python_api/api_def_CropAndResize.pbtxt b/tensorflow/core/api_def/python_api/api_def_CropAndResize.pbtxt
index ce65f81..2559a6c8 100644
--- a/tensorflow/core/api_def/python_api/api_def_CropAndResize.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_CropAndResize.pbtxt
@@ -1,6 +1,4 @@
op {
graph_op_name: "CropAndResize"
- endpoint {
- name: "image.crop_and_resize"
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeAndCropJpeg.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeAndCropJpeg.pbtxt
index fbe9c88..2c3857c 100644
--- a/tensorflow/core/api_def/python_api/api_def_DecodeAndCropJpeg.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeAndCropJpeg.pbtxt
@@ -1,6 +1,4 @@
op {
graph_op_name: "DecodeAndCropJpeg"
- endpoint {
- name: "image.decode_and_crop_jpeg"
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeBmp.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeBmp.pbtxt
index 573d83f..ffe19ca 100644
--- a/tensorflow/core/api_def/python_api/api_def_DecodeBmp.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeBmp.pbtxt
@@ -1,6 +1,4 @@
op {
graph_op_name: "DecodeBmp"
- endpoint {
- name: "image.decode_bmp"
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeGif.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeGif.pbtxt
index eed64df..ff68b99 100644
--- a/tensorflow/core/api_def/python_api/api_def_DecodeGif.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeGif.pbtxt
@@ -1,6 +1,4 @@
op {
graph_op_name: "DecodeGif"
- endpoint {
- name: "image.decode_gif"
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeJpeg.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeJpeg.pbtxt
index 994bc4e..97d262a 100644
--- a/tensorflow/core/api_def/python_api/api_def_DecodeJpeg.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeJpeg.pbtxt
@@ -1,6 +1,4 @@
op {
graph_op_name: "DecodeJpeg"
- endpoint {
- name: "image.decode_jpeg"
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodePng.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodePng.pbtxt
index 309eec5..3b9290a 100644
--- a/tensorflow/core/api_def/python_api/api_def_DecodePng.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_DecodePng.pbtxt
@@ -1,6 +1,4 @@
op {
graph_op_name: "DecodePng"
- endpoint {
- name: "image.decode_png"
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Dilation2D.pbtxt b/tensorflow/core/api_def/python_api/api_def_Dilation2D.pbtxt
index 6d73ecf..1bd83d9 100644
--- a/tensorflow/core/api_def/python_api/api_def_Dilation2D.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Dilation2D.pbtxt
@@ -2,5 +2,6 @@
graph_op_name: "Dilation2D"
endpoint {
name: "nn.dilation2d"
+ deprecation_version: 2
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_EncodeJpeg.pbtxt b/tensorflow/core/api_def/python_api/api_def_EncodeJpeg.pbtxt
index 5c31e9d..054ffb9 100644
--- a/tensorflow/core/api_def/python_api/api_def_EncodeJpeg.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_EncodeJpeg.pbtxt
@@ -1,6 +1,4 @@
op {
graph_op_name: "EncodeJpeg"
- endpoint {
- name: "image.encode_jpeg"
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt b/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt
index 0bd8b1c..17921de 100644
--- a/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ExtractImagePatches.pbtxt
@@ -1,10 +1,4 @@
op {
graph_op_name: "ExtractImagePatches"
- endpoint {
- name: "image.extract_image_patches"
- }
- endpoint {
- name: "extract_image_patches"
- deprecation_version: 2
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_ExtractJpegShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_ExtractJpegShape.pbtxt
index 6849a6d..a57955c 100644
--- a/tensorflow/core/api_def/python_api/api_def_ExtractJpegShape.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ExtractJpegShape.pbtxt
@@ -1,6 +1,4 @@
op {
graph_op_name: "ExtractJpegShape"
- endpoint {
- name: "image.extract_jpeg_shape"
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_MaxPoolWithArgmax.pbtxt b/tensorflow/core/api_def/python_api/api_def_MaxPoolWithArgmax.pbtxt
index 7d8abca..13a1a0b 100644
--- a/tensorflow/core/api_def/python_api/api_def_MaxPoolWithArgmax.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MaxPoolWithArgmax.pbtxt
@@ -2,5 +2,6 @@
graph_op_name: "MaxPoolWithArgmax"
endpoint {
name: "nn.max_pool_with_argmax"
+ deprecation_version: 2
}
}
diff --git a/tensorflow/core/common_runtime/accumulate_n_optimizer.cc b/tensorflow/core/common_runtime/accumulate_n_optimizer.cc
index 822d006..c4bc1a6 100644
--- a/tensorflow/core/common_runtime/accumulate_n_optimizer.cc
+++ b/tensorflow/core/common_runtime/accumulate_n_optimizer.cc
@@ -74,8 +74,7 @@
Status rewriteNode(Node* n, Graph* g) {
AttrSlice n_attrs = n->attrs();
- auto base_make_node = [n, g, &n_attrs](const string& op,
- const string& name) {
+ auto base_make_node = [n, &n_attrs](const string& op, const string& name) {
NodeBuilder node_builder(name, op);
// The pieces of AccumulateNV2 should all be on the same node.
@@ -86,7 +85,7 @@
}
return node_builder;
};
- auto make_node = [n, g, &n_attrs, &base_make_node](string op) {
+ auto make_node = [n, g, &base_make_node](string op) {
return base_make_node(
op, g->NewName(strings::StrCat(n->name(), "/Internal")));
};
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
index 91994c5..f3d86aa 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
+++ b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
@@ -38,8 +38,9 @@
auto* device_count = options.config.mutable_device_count();
string task_name = "/job:localhost/replica:0/task:0";
device_count->insert({"CPU", NUM_DEVS});
- TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
- device_mgr_.reset(new DeviceMgr(devices_));
+ std::vector<std::unique_ptr<Device>> devices;
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
+ device_mgr_.reset(new DeviceMgr(std::move(devices)));
std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(device_mgr_.get()));
std::unique_ptr<ParamResolverInterface> prl(
@@ -50,7 +51,6 @@
}
std::unique_ptr<CollectiveExecutorMgr> cme_;
- std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
};
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 624d3f2..a8e3f4c 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -696,7 +696,7 @@
if (ir->source_rank >= 0) {
ir->status = errors::Internal("Instance ", cp->instance.instance_key,
" already has source ", ir->source_rank,
- ", recevied second claim from ",
+ ", received second claim from ",
cp->default_rank);
} else {
ir->source_rank = cp->default_rank;
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
index 9a501b3..94d889c 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
@@ -37,8 +37,9 @@
string task_name = "/job:localhost/replica:0/task:0";
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", NUM_DEVS});
- TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
- device_mgr_.reset(new DeviceMgr(devices_));
+ std::vector<std::unique_ptr<Device>> devices;
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
+ device_mgr_.reset(new DeviceMgr(std::move(devices)));
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
task_name));
@@ -73,7 +74,6 @@
}
}
- std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_;
std::unique_ptr<CollectiveParamResolverLocal> prl_;
diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc
index a931fe6..4263f3a 100644
--- a/tensorflow/core/common_runtime/collective_rma_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc
@@ -42,8 +42,9 @@
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", NUM_DEVS});
- TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices_));
- device_mgr_.reset(new DeviceMgr(devices_));
+ std::vector<std::unique_ptr<Device>> devices;
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices));
+ device_mgr_.reset(new DeviceMgr(std::move(devices)));
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
kTaskName));
@@ -51,7 +52,6 @@
kStepId));
}
- std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_;
std::unique_ptr<CollectiveParamResolverLocal> prl_;
diff --git a/tensorflow/core/common_runtime/device_factory.cc b/tensorflow/core/common_runtime/device_factory.cc
index b949001..0fad13f 100644
--- a/tensorflow/core/common_runtime/device_factory.cc
+++ b/tensorflow/core/common_runtime/device_factory.cc
@@ -20,6 +20,7 @@
#include <unordered_map>
#include <vector>
+#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -89,9 +90,9 @@
return it->second.factory.get();
}
-Status DeviceFactory::AddDevices(const SessionOptions& options,
- const string& name_prefix,
- std::vector<Device*>* devices) {
+Status DeviceFactory::AddDevices(
+ const SessionOptions& options, const string& name_prefix,
+ std::vector<std::unique_ptr<Device>>* devices) {
// CPU first. A CPU device is required.
auto cpu_factory = GetFactory("CPU");
if (!cpu_factory) {
@@ -116,16 +117,16 @@
return Status::OK();
}
-Device* DeviceFactory::NewDevice(const string& type,
- const SessionOptions& options,
- const string& name_prefix) {
+std::unique_ptr<Device> DeviceFactory::NewDevice(const string& type,
+ const SessionOptions& options,
+ const string& name_prefix) {
auto device_factory = GetFactory(type);
if (!device_factory) {
return nullptr;
}
SessionOptions opt = options;
(*opt.config.mutable_device_count())[type] = 1;
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(device_factory->CreateDevices(opt, name_prefix, &devices));
int expected_num_devices = 1;
auto iter = options.config.device_count().find(type);
@@ -133,7 +134,7 @@
expected_num_devices = iter->second;
}
DCHECK_EQ(devices.size(), static_cast<size_t>(expected_num_devices));
- return devices[0];
+ return std::move(devices[0]);
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/device_factory.h b/tensorflow/core/common_runtime/device_factory.h
index db50226..b3cd7ad 100644
--- a/tensorflow/core/common_runtime/device_factory.h
+++ b/tensorflow/core/common_runtime/device_factory.h
@@ -40,18 +40,19 @@
// CPU devices are added first.
static Status AddDevices(const SessionOptions& options,
const string& name_prefix,
- std::vector<Device*>* devices);
+ std::vector<std::unique_ptr<Device>>* devices);
// Helper for tests. Create a single device of type "type". The
// returned device is always numbered zero, so if creating multiple
// devices of the same type, supply distinct name_prefix arguments.
- static Device* NewDevice(const string& type, const SessionOptions& options,
- const string& name_prefix);
+ static std::unique_ptr<Device> NewDevice(const string& type,
+ const SessionOptions& options,
+ const string& name_prefix);
// Most clients should call AddDevices() instead.
- virtual Status CreateDevices(const SessionOptions& options,
- const string& name_prefix,
- std::vector<Device*>* devices) = 0;
+ virtual Status CreateDevices(
+ const SessionOptions& options, const string& name_prefix,
+ std::vector<std::unique_ptr<Device>>* devices) = 0;
// Return the device priority number for a "device_type" string.
//
diff --git a/tensorflow/core/common_runtime/device_mgr.cc b/tensorflow/core/common_runtime/device_mgr.cc
index 470abc1..1f7d7c4 100644
--- a/tensorflow/core/common_runtime/device_mgr.cc
+++ b/tensorflow/core/common_runtime/device_mgr.cc
@@ -15,6 +15,7 @@
#include "tensorflow/core/common_runtime/device_mgr.h"
+#include <memory>
#include <vector>
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
@@ -24,32 +25,32 @@
namespace tensorflow {
-DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
- : name_backing_store_(128) {
- for (Device* d : devices) {
+DeviceMgr::DeviceMgr(std::vector<std::unique_ptr<Device>> devices)
+ : devices_(std::move(devices)), name_backing_store_(128) {
+ for (auto& d : devices_) {
CHECK(d->device_mgr_ == nullptr);
d->device_mgr_ = this;
- devices_.push_back(d);
-
// Register under the (1) full name and (2) canonical name.
for (const string& name :
DeviceNameUtils::GetNamesForDeviceMappings(d->parsed_name())) {
- device_map_[CopyToBackingStore(name)] = d;
+ device_map_[CopyToBackingStore(name)] = d.get();
}
// Register under the (3) local name and (4) legacy local name.
for (const string& name :
DeviceNameUtils::GetLocalNamesForDeviceMappings(d->parsed_name())) {
- device_map_[CopyToBackingStore(name)] = d;
+ device_map_[CopyToBackingStore(name)] = d.get();
}
device_type_counts_[d->device_type()]++;
}
}
-DeviceMgr::~DeviceMgr() {
- // TODO(b/37437134): Remove destructor after converting to std::unique_ptr.
- for (Device* p : devices_) delete p;
-}
+DeviceMgr::DeviceMgr(std::unique_ptr<Device> device)
+ : DeviceMgr([&device] {
+ std::vector<std::unique_ptr<Device>> vector;
+ vector.push_back(std::move(device));
+ return vector;
+ }()) {}
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
size_t n = s.size();
@@ -61,18 +62,22 @@
void DeviceMgr::ListDeviceAttributes(
std::vector<DeviceAttributes>* devices) const {
devices->reserve(devices_.size());
- for (Device* dev : devices_) {
+ for (const auto& dev : devices_) {
devices->emplace_back(dev->attributes());
}
}
std::vector<Device*> DeviceMgr::ListDevices() const {
- return std::vector<Device*>(devices_.begin(), devices_.end());
+ std::vector<Device*> devices(devices_.size());
+ for (size_t i = 0; i < devices_.size(); ++i) {
+ devices[i] = devices_[i].get();
+ }
+ return devices;
}
string DeviceMgr::DebugString() const {
string out;
- for (Device* dev : devices_) {
+ for (const auto& dev : devices_) {
strings::StrAppend(&out, dev->name(), "\n");
}
return out;
@@ -80,7 +85,7 @@
string DeviceMgr::DeviceMappingString() const {
string out;
- for (Device* dev : devices_) {
+ for (const auto& dev : devices_) {
if (!dev->attributes().physical_device_desc().empty()) {
strings::StrAppend(&out, dev->name(), " -> ",
dev->attributes().physical_device_desc(), "\n");
@@ -107,7 +112,7 @@
void DeviceMgr::ClearContainers(gtl::ArraySlice<string> containers) const {
Status s;
- for (Device* dev : devices_) {
+ for (const auto& dev : devices_) {
if (containers.empty()) {
s.Update(dev->resource_manager()->Cleanup(
dev->resource_manager()->default_container()));
diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h
index c1ff10d..bf86946 100644
--- a/tensorflow/core/common_runtime/device_mgr.h
+++ b/tensorflow/core/common_runtime/device_mgr.h
@@ -16,6 +16,7 @@
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
+#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
@@ -34,15 +35,17 @@
class DeviceMgr {
public:
- // Takes ownership of each device in 'devices'.
+ // Constructs a DeviceMgr from a list of devices.
// TODO(zhifengc): Other initialization information.
- // TODO(b/37437134): Use std::unique_ptr's to track ownership.
- explicit DeviceMgr(const std::vector<Device*>& devices);
- ~DeviceMgr();
+ explicit DeviceMgr(std::vector<std::unique_ptr<Device>> devices);
+
+ // Constructs a DeviceMgr managing a single device.
+ explicit DeviceMgr(std::unique_ptr<Device> device);
// Returns attributes of all devices.
void ListDeviceAttributes(std::vector<DeviceAttributes>* devices) const;
+ // Returns raw pointers to the underlying devices.
std::vector<Device*> ListDevices() const;
// Returns a string listing all devices.
@@ -62,9 +65,7 @@
int NumDeviceType(const string& type) const;
private:
- // TODO(b/37437134): Use std::unique_ptr's to track ownership.
- typedef gtl::InlinedVector<Device*, 8> DeviceVec;
- DeviceVec devices_;
+ const std::vector<std::unique_ptr<Device>> devices_;
StringPiece CopyToBackingStore(StringPiece s);
diff --git a/tensorflow/core/common_runtime/device_resolver_local_test.cc b/tensorflow/core/common_runtime/device_resolver_local_test.cc
index f5a6471..54f1119 100644
--- a/tensorflow/core/common_runtime/device_resolver_local_test.cc
+++ b/tensorflow/core/common_runtime/device_resolver_local_test.cc
@@ -36,12 +36,12 @@
string task_name = "/job:localhost/replica:0/task:0";
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", NUM_DEVS});
- TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
- device_mgr_.reset(new DeviceMgr(devices_));
+ std::vector<std::unique_ptr<Device>> devices;
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
+ device_mgr_.reset(new DeviceMgr(std::move(devices)));
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
}
- std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_;
};
diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc
index fd9c422..6a8c3d1 100644
--- a/tensorflow/core/common_runtime/device_set_test.cc
+++ b/tensorflow/core/common_runtime/device_set_test.cc
@@ -57,7 +57,7 @@
class DummyFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector<Device*>* devices) override {
+ std::vector<std::unique_ptr<Device>>* devices) override {
return Status::OK();
}
};
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 40b7071..0434ca4 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -30,6 +30,7 @@
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
+#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
@@ -155,12 +156,12 @@
if (options.config.graph_options().build_cost_model() > 0) {
EnableCPUAllocatorFullStats(true);
}
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices));
DirectSession* session =
- new DirectSession(options, new DeviceMgr(devices), this);
+ new DirectSession(options, new DeviceMgr(std::move(devices)), this);
{
mutex_lock l(sessions_lock_);
sessions_.push_back(session);
@@ -462,6 +463,7 @@
CallFrameInterface* call_frame,
ExecutorsAndKeys* executors_and_keys,
RunMetadata* run_metadata) {
+ const uint64 start_time_usecs = Env::Default()->NowMicros();
string session_id_meta = strings::StrCat("SessionRun #id=", step_id, "#");
tracing::ScopedActivity activity(session_id_meta);
@@ -716,6 +718,7 @@
exec_and_lib.graph->ToGraphDef(partition_graph_def);
}
}
+ UpdateGraphExecTime(Env::Default()->NowMicros() - start_time_usecs);
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index a7b618c..86890ba 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -181,6 +181,7 @@
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 4de807b..51109f8 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -206,6 +206,8 @@
bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }
+ tensorflow::Env* TFEnv() const { return env_; }
+
private:
void InitDeviceMapAndAsync();
Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef);
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
index 948bdbc..3ffed3c 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
@@ -18,6 +18,7 @@
#include <memory>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
@@ -37,12 +38,13 @@
class TestEnv {
public:
TestEnv() : flib_def_(OpRegistry::Global(), {}) {
- Device* device =
- DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
- device_mgr_.reset(new DeviceMgr({device}));
- flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(),
- device, TF_GRAPH_DEF_VERSION,
- &flib_def_, nullptr, {}, nullptr);
+ std::vector<std::unique_ptr<Device>> devices;
+ devices.push_back(
+ DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"));
+ device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
+ flib_runtime_ = NewFunctionLibraryRuntime(
+ device_mgr_.get(), Env::Default(), device_mgr_->ListDevices()[0],
+ TF_GRAPH_DEF_VERSION, &flib_def_, nullptr, {}, nullptr);
}
FunctionLibraryRuntime* function_library_runtime() const {
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 77b249c..6b3284b 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -1239,7 +1239,6 @@
// Step-local container.
ScopedStepContainer* step_container_;
StepStatsCollectorInterface* const stats_collector_;
- const tracing::TraceCollector* const trace_collector_;
const tracing::EventCollector* const event_collector_;
Context context_;
@@ -1366,7 +1365,6 @@
tensor_store_(args.tensor_store),
step_container_(args.step_container),
stats_collector_(args.stats_collector),
- trace_collector_(tracing::GetTraceCollector()),
event_collector_(
tracing::GetEventCollector(tracing::EventCategory::kCompute)),
context_(ContextKind::kThread),
@@ -1565,7 +1563,6 @@
// Returns true if `item` might be traced by the given trace and event
// collectors. Returns false only if `item` definitely will not be traced.
bool MightTrace(const NodeItem& item,
- const tracing::TraceCollector* trace_collector,
const tracing::EventCollector* event_collector,
bool using_annotations) {
// Tracing will only be enabled if either `event_collector` is non null,
@@ -1578,6 +1575,7 @@
if (event_collector != nullptr) {
return true;
}
+ auto* trace_collector = tracing::GetTraceCollector();
if (trace_collector) {
if (using_annotations) {
return trace_collector->IsEnabledForAnnotations();
@@ -1762,9 +1760,8 @@
OpKernelContext ctx(¶ms, item.num_outputs);
nodestats::SetOpStart(stats);
- if (TF_PREDICT_FALSE(MightTrace(item, trace_collector_,
- event_collector_,
- trace_using_annotations_))) {
+ if (TF_PREDICT_FALSE(
+ MightTrace(item, event_collector_, trace_using_annotations_))) {
const string& op_name = op_kernel->name();
tracing::ScopedRegion region(tracing::EventCategory::kCompute,
op_name);
@@ -2048,13 +2045,14 @@
TaggedNodeSeq* ready) {
auto activity_handle =
[&]() -> std::unique_ptr<tracing::TraceCollector::Handle> {
- if (TF_PREDICT_FALSE(trace_collector_ != nullptr &&
- trace_collector_->IsEnabledForActivities(
+ auto* trace_collector = tracing::GetTraceCollector();
+ if (TF_PREDICT_FALSE(trace_collector != nullptr &&
+ trace_collector->IsEnabledForActivities(
false /* is_expensive */))) {
const string& op_name = item->kernel->name();
// Intentionally using ExecutorPropagateOutputs as the first key so that
// users are aware that it's not the op invocation.
- return trace_collector_->CreateActivityHandle(
+ return trace_collector->CreateActivityHandle(
"ExecutorPropagateOutputs",
strings::StrCat(op_name, "#id=", step_id_, "#"),
false /* is_expensive */);
diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc
index 7697103..c311b25 100644
--- a/tensorflow/core/common_runtime/executor_test.cc
+++ b/tensorflow/core/common_runtime/executor_test.cc
@@ -53,17 +53,17 @@
// when the test completes.
CHECK(rendez_->Unref());
delete exec_;
- delete device_;
}
// Resets executor_ with a new executor based on a graph 'gdef'.
void Create(std::unique_ptr<const Graph> graph) {
const int version = graph->versions().producer();
LocalExecutorParams params;
- params.device = device_;
+ params.device = device_.get();
params.create_kernel = [this, version](const NodeDef& ndef,
OpKernel** kernel) {
- return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
+ return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
+ kernel);
};
params.delete_kernel = [](OpKernel* kernel) {
DeleteNonCachedKernel(kernel);
@@ -83,7 +83,7 @@
}
thread::ThreadPool* thread_pool_ = nullptr;
- Device* device_ = nullptr;
+ std::unique_ptr<Device> device_;
Executor* exec_ = nullptr;
StepStatsCollector step_stats_collector_;
StepStats step_stats_;
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 13c189f..3b4c976 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -18,6 +18,7 @@
#include <atomic>
#include <utility>
+#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
@@ -147,14 +148,15 @@
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 3});
+ std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
- options, "/job:localhost/replica:0/task:0", &devices_));
+ options, "/job:localhost/replica:0/task:0", &devices));
FunctionDefLibrary proto;
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
OptimizerOptions opts;
- device_mgr_.reset(new DeviceMgr(devices_));
+ device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, default_thread_pool, nullptr /* cluster_flr */));
@@ -358,7 +360,6 @@
FunctionLibraryRuntime* flr0_;
FunctionLibraryRuntime* flr1_;
FunctionLibraryRuntime* flr2_;
- std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
diff --git a/tensorflow/core/common_runtime/function_threadpool_test.cc b/tensorflow/core/common_runtime/function_threadpool_test.cc
index 655a68c..1b80373 100644
--- a/tensorflow/core/common_runtime/function_threadpool_test.cc
+++ b/tensorflow/core/common_runtime/function_threadpool_test.cc
@@ -54,21 +54,19 @@
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 3});
+ std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(
- options, "/job:localhost/replica:0/task:0", &devices_));
+ options, "/job:localhost/replica:0/task:0", &devices));
FunctionDefLibrary proto;
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
OptimizerOptions opts;
- device_mgr_.reset(new DeviceMgr(devices_));
+ device_mgr_.reset(new DeviceMgr(std::move(devices)));
pflr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts, default_thread_pool, nullptr /* cluster_flr */));
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
- flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
- flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");
- fdef_lib_ = lib_def_->ToProto();
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
@@ -192,13 +190,9 @@
}
FunctionLibraryRuntime* flr0_;
- FunctionLibraryRuntime* flr1_;
- FunctionLibraryRuntime* flr2_;
- std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
- FunctionDefLibrary fdef_lib_;
};
TEST_F(FunctionLibraryRuntimeTest, DefaultThreadpool) {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 81fea31..5152d97 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -907,9 +907,9 @@
const int BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength = 1000;
const int BaseGPUDeviceFactory::InterconnectMap::kStreamExecutorStrength = 1;
-Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
- const string& name_prefix,
- std::vector<Device*>* devices) {
+Status BaseGPUDeviceFactory::CreateDevices(
+ const SessionOptions& options, const string& name_prefix,
+ std::vector<std::unique_ptr<Device>>* devices) {
TF_RETURN_IF_ERROR(ValidateGPUMachineManager());
se::Platform* gpu_manager = GPUMachineManager();
if (gpu_manager == nullptr) {
@@ -1073,12 +1073,10 @@
// LINT.ThenChange(//tensorflow/python/platform/test.py)
}
-Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
- const string& name_prefix,
- TfGpuId tf_gpu_id,
- int64 memory_limit,
- const DeviceLocality& dev_locality,
- std::vector<Device*>* devices) {
+Status BaseGPUDeviceFactory::CreateGPUDevice(
+ const SessionOptions& options, const string& name_prefix, TfGpuId tf_gpu_id,
+ int64 memory_limit, const DeviceLocality& dev_locality,
+ std::vector<std::unique_ptr<Device>>* devices) {
CHECK_GE(tf_gpu_id.value(), 0);
const string device_name =
strings::StrCat(name_prefix, "/device:GPU:", tf_gpu_id.value());
@@ -1108,7 +1106,7 @@
// different (which should be an error).
//
// TODO(laigd): report error if memory_limit doesn't match stats.bytes_limit.
- BaseGPUDevice* gpu_device = CreateGPUDevice(
+ std::unique_ptr<BaseGPUDevice> gpu_device = CreateGPUDevice(
options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality,
tf_gpu_id, GetShortDeviceDescription(platform_gpu_id, desc),
gpu_allocator, ProcessState::singleton()->GetCPUAllocator(numa_node));
@@ -1116,7 +1114,7 @@
<< (stats.bytes_limit >> 20) << " MB memory) -> physical GPU ("
<< GetShortDeviceDescription(platform_gpu_id, desc) << ")";
TF_RETURN_IF_ERROR(gpu_device->Init(options));
- devices->push_back(gpu_device);
+ devices->push_back(std::move(gpu_device));
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index 674e838..d002d02 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -166,7 +166,7 @@
class BaseGPUDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector<Device*>* devices) override;
+ std::vector<std::unique_ptr<Device>>* devices) override;
struct InterconnectMap {
// Name of interconnect technology, if known.
@@ -207,15 +207,13 @@
Status CreateGPUDevice(const SessionOptions& options,
const string& name_prefix, TfGpuId tf_gpu_id,
int64 memory_limit, const DeviceLocality& dev_locality,
- std::vector<Device*>* devices);
+ std::vector<std::unique_ptr<Device>>* devices);
- virtual BaseGPUDevice* CreateGPUDevice(const SessionOptions& options,
- const string& name, Bytes memory_limit,
- const DeviceLocality& dev_locality,
- TfGpuId tf_gpu_id,
- const string& physical_device_desc,
- Allocator* gpu_allocator,
- Allocator* cpu_allocator) = 0;
+ virtual std::unique_ptr<BaseGPUDevice> CreateGPUDevice(
+ const SessionOptions& options, const string& name, Bytes memory_limit,
+ const DeviceLocality& dev_locality, TfGpuId tf_gpu_id,
+ const string& physical_device_desc, Allocator* gpu_allocator,
+ Allocator* cpu_allocator) = 0;
// Returns into 'ids' the list of valid platform GPU ids, in the order that
// they should map to TF GPU ids "/device:GPU:0", "/device:GPU:1", etc,
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
index e1aaf95..8dc7197 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
@@ -59,15 +59,14 @@
class GPUDeviceFactory : public BaseGPUDeviceFactory {
private:
- BaseGPUDevice* CreateGPUDevice(const SessionOptions& options,
- const string& name, Bytes memory_limit,
- const DeviceLocality& locality,
- TfGpuId tf_gpu_id,
- const string& physical_device_desc,
- Allocator* gpu_allocator,
- Allocator* cpu_allocator) override {
- return new GPUDevice(options, name, memory_limit, locality, tf_gpu_id,
- physical_device_desc, gpu_allocator, cpu_allocator);
+ std::unique_ptr<BaseGPUDevice> CreateGPUDevice(
+ const SessionOptions& options, const string& name, Bytes memory_limit,
+ const DeviceLocality& locality, TfGpuId tf_gpu_id,
+ const string& physical_device_desc, Allocator* gpu_allocator,
+ Allocator* cpu_allocator) override {
+ return absl::make_unique<GPUDevice>(options, name, memory_limit, locality,
+ tf_gpu_id, physical_device_desc,
+ gpu_allocator, cpu_allocator);
}
};
@@ -108,7 +107,7 @@
class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector<Device*>* devices) override {
+ std::vector<std::unique_ptr<Device>>* devices) override {
int n = 1;
auto iter = options.config.device_count().find("CPU");
if (iter != options.config.device_count().end()) {
@@ -116,7 +115,7 @@
}
for (int i = 0; i < n; i++) {
string name = strings::StrCat(name_prefix, "/device:CPU:", i);
- devices->push_back(new GPUCompatibleCPUDevice(
+ devices->push_back(absl::make_unique<GPUCompatibleCPUDevice>(
options, name, Bytes(256 << 20), DeviceLocality(), cpu_allocator()));
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_on_non_gpu_machine_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_on_non_gpu_machine_test.cc
index 75be6d6..58656ec 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_on_non_gpu_machine_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_on_non_gpu_machine_test.cc
@@ -33,7 +33,7 @@
TEST(GPUDeviceOnNonGPUMachineTest, CreateGPUDevicesOnNonGPUMachine) {
SessionOptions opts;
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, "/job:localhost/replica:0/task:0", &devices));
EXPECT_TRUE(devices.empty());
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index 3629409..ae623b2 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -88,7 +88,7 @@
TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) {
SessionOptions opts = MakeSessionOptions("0,abc");
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@@ -97,7 +97,7 @@
TEST_F(GPUDeviceTest, InvalidGpuId) {
SessionOptions opts = MakeSessionOptions("100");
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@@ -107,7 +107,7 @@
TEST_F(GPUDeviceTest, DuplicateEntryInVisibleDeviceList) {
SessionOptions opts = MakeSessionOptions("0,0");
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@@ -117,7 +117,7 @@
TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithMemoryFractionSettings) {
SessionOptions opts = MakeSessionOptions("0", 0.1, 1, {{}});
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@@ -129,7 +129,7 @@
// device_count is 0, but with one entry in visible_device_list and one
// (empty) VirtualDevices messages.
SessionOptions opts = MakeSessionOptions("0", 0, 0, {{}});
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::UNKNOWN);
@@ -141,7 +141,7 @@
// Single entry in visible_device_list with two (empty) VirtualDevices
// messages.
SessionOptions opts = MakeSessionOptions("0", 0, 8, {{}, {}});
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::UNKNOWN);
@@ -155,7 +155,7 @@
// Three entries in visible_device_list with two (empty) VirtualDevices
// messages.
SessionOptions opts = MakeSessionOptions("0,1", 0, 8, {{}});
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INVALID_ARGUMENT);
@@ -169,39 +169,36 @@
TEST_F(GPUDeviceTest, EmptyVirtualDeviceConfig) {
// It'll create single virtual device when the virtual device config is empty.
SessionOptions opts = MakeSessionOptions("0");
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(1, devices.size());
EXPECT_GE(devices[0]->attributes().memory_limit(), 0);
- gtl::STLDeleteElements(&devices);
}
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithNoMemoryLimit) {
// It'll create single virtual device for the gpu in question when
// memory_limit_mb is unset.
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{}});
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(1, devices.size());
EXPECT_GE(devices[0]->attributes().memory_limit(), 0);
- gtl::STLDeleteElements(&devices);
}
TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimit) {
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123}});
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(1, devices.size());
EXPECT_EQ(123 << 20, devices[0]->attributes().memory_limit());
- gtl::STLDeleteElements(&devices);
}
TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}});
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
EXPECT_EQ(2, devices.size());
@@ -219,7 +216,6 @@
devices[1]->attributes().locality().links().link(0).type());
EXPECT_EQ(BaseGPUDeviceFactory::InterconnectMap::kSameDeviceStrength,
devices[1]->attributes().locality().links().link(0).strength());
- gtl::STLDeleteElements(&devices);
}
// Enabling unified memory on pre-Pascal GPUs results in an initialization
@@ -236,7 +232,7 @@
opts.config.mutable_gpu_options()
->mutable_experimental()
->set_use_unified_memory(true);
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices);
EXPECT_EQ(status.code(), error::INTERNAL);
@@ -259,7 +255,7 @@
}
SessionOptions opts = MakeSessionOptions("0", kGpuMemoryFraction);
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
TF_ASSERT_OK(DeviceFactory::GetFactory("GPU")->CreateDevices(
opts, kDeviceNamePrefix, &devices));
ASSERT_EQ(1, devices.size());
@@ -278,8 +274,6 @@
(memory_limit >> 20) << 20);
EXPECT_NE(ptr, nullptr);
allocator->DeallocateRaw(ptr);
-
- gtl::STLDeleteElements(&devices);
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
index 2144eea..f0656ff 100644
--- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
@@ -15,6 +15,7 @@
#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include <algorithm>
+#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@@ -217,7 +218,7 @@
<< " num_devices_per_worker=" << num_devices_per_worker;
int total_num_devices = num_workers * num_devices_per_worker;
device_type_ = device_type;
- std::vector<Device*> local_devices;
+ std::vector<std::unique_ptr<Device>> local_devices;
SessionOptions sess_opts;
sess_opts.env = Env::Default();
Bytes mem_limit(4 << 20);
@@ -227,7 +228,7 @@
if (device_type == DEVICE_CPU) {
string dev_name = strings::StrCat("/job:worker/replica:0/task:", wi,
"/device:CPU:", di);
- local_devices.push_back(new ThreadPoolDevice(
+ local_devices.push_back(absl::make_unique<ThreadPoolDevice>(
sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator()));
} else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) {
int dev_idx = (wi * num_devices_per_worker) + di;
@@ -235,7 +236,7 @@
LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more "
"than one ring node.";
} else {
- local_devices.push_back(gpu_devices_[dev_idx]);
+ local_devices.push_back(std::move(gpu_devices_[dev_idx]));
}
} else {
LOG(FATAL) << "Unsupported device_type " << device_type;
@@ -243,7 +244,7 @@
}
}
if (!dev_mgr_ || device_type == DEVICE_CPU) {
- dev_mgr_.reset(new DeviceMgr(local_devices));
+ dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
}
if (!gpu_ring_order_) gpu_ring_order_.reset(new string());
dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get()));
@@ -714,7 +715,7 @@
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
std::vector<DeviceInstance*> instances_;
CollectiveParams col_params_;
- std::vector<tensorflow::Device*> gpu_devices_;
+ std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
std::unique_ptr<string> gpu_ring_order_;
mutex mu_;
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
index 1f585a8..bdd6c0e 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
@@ -75,12 +75,12 @@
const int graph_def_version = g->versions().producer();
LocalExecutorParams params;
- params.device = device_;
+ params.device = device_.get();
params.function_library = nullptr;
params.create_kernel = [this, graph_def_version](const NodeDef& ndef,
OpKernel** kernel) {
- return CreateNonCachedKernel(device_, nullptr, ndef, graph_def_version,
- kernel);
+ return CreateNonCachedKernel(device_.get(), nullptr, ndef,
+ graph_def_version, kernel);
};
params.delete_kernel = [](OpKernel* kernel) {
DeleteNonCachedKernel(kernel);
@@ -107,7 +107,7 @@
// run kernel destructors that may attempt to access state borrowed from
// `device_`, such as the resource manager.
exec_.reset();
- delete device_;
+ device_.reset();
delete pool_;
}
}
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
index 555b43f..b1557c5 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
@@ -55,7 +55,7 @@
private:
thread::ThreadPool* pool_ = nullptr;
- Device* device_ = nullptr;
+ std::unique_ptr<Device> device_ = nullptr;
Rendezvous* rendez_ = nullptr;
std::unique_ptr<Executor> exec_;
diff --git a/tensorflow/core/common_runtime/metrics.cc b/tensorflow/core/common_runtime/metrics.cc
new file mode 100644
index 0000000..f4c94ed
--- /dev/null
+++ b/tensorflow/core/common_runtime/metrics.cc
@@ -0,0 +1,40 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/metrics.h"
+#include "tensorflow/core/lib/monitoring/counter.h"
+
+namespace tensorflow {
+
+namespace {
+
+auto* graph_runs = monitoring::Counter<0>::New(
+ "/tensorflow/core/graph_runs",
+ "The number of graph executions used to collect "
+ "/tensorflow/core/graph_run_time_usecs");
+
+auto* graph_run_time_usecs = monitoring::Counter<0>::New(
+ "/tensorflow/core/graph_run_time_usecs",
+ "The total time spent on executing graphs in microseconds.");
+} // namespace
+
+void UpdateGraphExecTime(const uint64 running_time_usecs) {
+ if (running_time_usecs > 0) {
+ graph_runs->GetCell()->IncrementBy(1);
+ graph_run_time_usecs->GetCell()->IncrementBy(running_time_usecs);
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/metrics.h b/tensorflow/core/common_runtime/metrics.h
new file mode 100644
index 0000000..d3430c9
--- /dev/null
+++ b/tensorflow/core/common_runtime/metrics.h
@@ -0,0 +1,27 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_METRICS_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_METRICS_H_
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+void UpdateGraphExecTime(const uint64 running_time_usecs);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_METRICS_H_
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 009f905..04e77e5 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -92,7 +92,7 @@
class DummyFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector<Device*>* devices) override {
+ std::vector<std::unique_ptr<Device>>* devices) override {
return Status::OK();
}
};
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index cce2308..21cb621 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -62,9 +62,12 @@
SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 2});
+ std::vector<std::unique_ptr<Device>> devices;
TF_CHECK_OK(DeviceFactory::AddDevices(options, "/job:a/replica:0/task:0",
- &devices_));
- device_mgr_.reset(new DeviceMgr(devices_));
+ &devices));
+ device0_ = devices[0].get();
+ device1_ = devices[1].get();
+ device_mgr_.reset(new DeviceMgr(std::move(devices)));
FunctionDefLibrary proto;
for (const auto& fdef : flib) *(proto.add_function()) = fdef;
lib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), proto));
@@ -138,8 +141,9 @@
return Status::OK();
}
- std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
+ Device* device0_ = nullptr; // Not owned. (Owned by device_mgr_.)
+ Device* device1_ = nullptr; // Not owned. (Owned by device_mgr_.)
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<TestClusterFLR> cluster_flr_;
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
@@ -165,16 +169,16 @@
FunctionLibraryRuntime* flr =
proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:0");
EXPECT_NE(flr, nullptr);
- EXPECT_EQ(flr->device(), devices_[0]);
+ EXPECT_EQ(flr->device(), device0_);
flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/device:CPU:0");
EXPECT_NE(flr, nullptr);
- EXPECT_EQ(flr->device(), devices_[0]);
+ EXPECT_EQ(flr->device(), device0_);
flr = proc_flr_->GetFLR("/device:CPU:0");
EXPECT_NE(flr, nullptr);
- EXPECT_EQ(flr->device(), devices_[0]);
+ EXPECT_EQ(flr->device(), device0_);
flr = proc_flr_->GetFLR("/job:a/replica:0/task:0/cpu:1");
EXPECT_NE(flr, nullptr);
- EXPECT_EQ(flr->device(), devices_[1]);
+ EXPECT_EQ(flr->device(), device1_);
flr = proc_flr_->GetFLR("abc");
EXPECT_EQ(flr, nullptr);
rendezvous_->Unref();
diff --git a/tensorflow/core/common_runtime/renamed_device.cc b/tensorflow/core/common_runtime/renamed_device.cc
index 56766a8..45541c3 100644
--- a/tensorflow/core/common_runtime/renamed_device.cc
+++ b/tensorflow/core/common_runtime/renamed_device.cc
@@ -14,15 +14,14 @@
==============================================================================*/
#include "tensorflow/core/common_runtime/renamed_device.h"
+#include "absl/memory/memory.h"
namespace tensorflow {
-// TODO(saeta): Convert to returning a std::unique_ptr?
/* static */
-Device* RenamedDevice::NewRenamedDevice(const string& new_base,
- Device* underlying,
- bool owns_underlying,
- bool isolate_session_state) {
+std::unique_ptr<Device> RenamedDevice::NewRenamedDevice(
+ const string& new_base, Device* underlying, bool owns_underlying,
+ bool isolate_session_state) {
DeviceNameUtils::ParsedName parsed_name;
CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name));
DeviceNameUtils::ParsedName underlying_parsed_name =
@@ -36,8 +35,9 @@
parsed_name.id);
DeviceAttributes attributes(underlying->attributes());
attributes.set_name(name);
- return new RenamedDevice(underlying, attributes, owns_underlying,
- isolate_session_state);
+ // Call absl::WrapUnique to access private constructor.
+ return absl::WrapUnique(new RenamedDevice(
+ underlying, attributes, owns_underlying, isolate_session_state));
}
RenamedDevice::RenamedDevice(Device* underlying,
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
index c00789a..6d24f49 100644
--- a/tensorflow/core/common_runtime/renamed_device.h
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -28,9 +28,10 @@
// session.
class RenamedDevice : public Device {
public:
- static Device* NewRenamedDevice(const string& new_base, Device* underlying,
- bool owns_underlying,
- bool isolate_session_state);
+ static std::unique_ptr<Device> NewRenamedDevice(const string& new_base,
+ Device* underlying,
+ bool owns_underlying,
+ bool isolate_session_state);
~RenamedDevice() override;
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index b1fe928..092f15e 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -290,7 +290,7 @@
col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
col_ctx_->output, 0 /*dev_to_dev_stream_index*/,
- [this, ¬e, &status](const Status& s) {
+ [¬e, &status](const Status& s) {
status.Update(s);
note.Notify();
});
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index a271bf7..7feb29a 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -15,6 +15,7 @@
#include "tensorflow/core/common_runtime/ring_reducer.h"
#include <algorithm>
+#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/device.h"
@@ -157,7 +158,7 @@
InitGPUDevices();
#endif
device_type_ = device_type;
- std::vector<Device*> local_devices;
+ std::vector<std::unique_ptr<Device>> local_devices;
SessionOptions sess_opts;
sess_opts.env = Env::Default();
Bytes mem_limit(4 << 20);
@@ -167,7 +168,7 @@
if (device_type == DEVICE_CPU) {
string dev_name =
strings::StrCat("/job:worker/replica:0/task:", wi, "/cpu:", di);
- local_devices.push_back(new ThreadPoolDevice(
+ local_devices.push_back(absl::make_unique<ThreadPoolDevice>(
sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator()));
} else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) {
int dev_idx = (wi * num_devices) + di;
@@ -175,7 +176,7 @@
LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more "
"than one ring node.";
} else {
- local_devices.push_back(gpu_devices_[dev_idx]);
+ local_devices.push_back(std::move(gpu_devices_[dev_idx]));
}
} else {
LOG(FATAL) << "Unsupported device_type " << device_type;
@@ -185,7 +186,7 @@
if (!dev_mgr_ || device_type == DEVICE_CPU) {
LOG(ERROR) << "resetting dev_mgr for " << local_devices.size()
<< " devices: ";
- dev_mgr_.reset(new DeviceMgr(local_devices));
+ dev_mgr_.reset(new DeviceMgr(std::move(local_devices)));
}
if (!gpu_ring_order_) gpu_ring_order_.reset(new string());
dev_resolver_.reset(new DeviceResolverLocal(dev_mgr_.get()));
@@ -544,7 +545,7 @@
std::unique_ptr<DeviceResolverLocal> dev_resolver_;
std::vector<DeviceInstance*> instances_;
CollectiveParams col_params_;
- std::vector<tensorflow::Device*> gpu_devices_;
+ std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
std::unique_ptr<string> gpu_ring_order_;
mutex mu_;
diff --git a/tensorflow/core/common_runtime/threadpool_device_factory.cc b/tensorflow/core/common_runtime/threadpool_device_factory.cc
index c06a403..f9cbb81 100644
--- a/tensorflow/core/common_runtime/threadpool_device_factory.cc
+++ b/tensorflow/core/common_runtime/threadpool_device_factory.cc
@@ -13,12 +13,13 @@
limitations under the License.
==============================================================================*/
-// Register a factory that provides CPU devices.
-#include "tensorflow/core/common_runtime/threadpool_device.h"
-
#include <vector>
+
+// Register a factory that provides CPU devices.
+#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/process_state.h"
+#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/numa.h"
#include "tensorflow/core/public/session_options.h"
@@ -29,7 +30,7 @@
class ThreadPoolDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector<Device*>* devices) override {
+ std::vector<std::unique_ptr<Device>>* devices) override {
int num_numa_nodes = port::NUMANumNodes();
int n = 1;
auto iter = options.config.device_count().find("CPU");
@@ -38,7 +39,7 @@
}
for (int i = 0; i < n; i++) {
string name = strings::StrCat(name_prefix, "/device:CPU:", i);
- ThreadPoolDevice* tpd = nullptr;
+ std::unique_ptr<ThreadPoolDevice> tpd;
if (options.config.experimental().use_numa_affinity()) {
int numa_node = i % num_numa_nodes;
if (numa_node != i) {
@@ -49,15 +50,15 @@
}
DeviceLocality dev_locality;
dev_locality.set_numa_node(numa_node);
- tpd = new ThreadPoolDevice(
+ tpd = absl::make_unique<ThreadPoolDevice>(
options, name, Bytes(256 << 20), dev_locality,
ProcessState::singleton()->GetCPUAllocator(numa_node));
} else {
- tpd = new ThreadPoolDevice(
+ tpd = absl::make_unique<ThreadPoolDevice>(
options, name, Bytes(256 << 20), DeviceLocality(),
ProcessState::singleton()->GetCPUAllocator(port::kNUMANoAffinity));
}
- devices->push_back(tpd);
+ devices->push_back(std::move(tpd));
}
return Status::OK();
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 8183247..e388d3e 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -425,6 +425,7 @@
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:metrics",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core/debug",
@@ -624,6 +625,7 @@
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
index 4eed856..40b18d3 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc
@@ -29,7 +29,8 @@
namespace tensorflow {
namespace {
-static Device* NewDevice(const string& type, const string& name) {
+static std::unique_ptr<Device> NewDevice(const string& type,
+ const string& name) {
class FakeDevice : public Device {
public:
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@@ -40,7 +41,7 @@
attr.set_name(name);
attr.set_device_type(type);
attr.mutable_locality()->set_numa_node(3); // a non-default value
- return new FakeDevice(attr);
+ return absl::make_unique<FakeDevice>(attr);
}
class FakeWorker : public TestWorkerInterface {
@@ -156,16 +157,16 @@
void DefineWorker(const ConfigProto& config, const string& worker_name,
const string& device_type, int num_devices) {
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
for (int i = 0; i < num_devices; ++i) {
devices.push_back(NewDevice(
device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
}
- DeviceMgr* dev_mgr = new DeviceMgr(devices);
+ DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name];
- for (auto d : devices) {
+ for (auto* d : dev_mgr->ListDevices()) {
dv->push_back(d->name());
}
DeviceResolverDistributed* dev_res =
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
index 33e1c8f..26f722a 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -41,7 +41,8 @@
namespace tensorflow {
namespace {
-static Device* NewDevice(const string& type, const string& name) {
+static std::unique_ptr<Device> NewDevice(const string& type,
+ const string& name) {
class FakeDevice : public Device {
public:
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@@ -52,7 +53,7 @@
attr.set_name(name);
attr.set_device_type(type);
attr.mutable_locality()->set_numa_node(3); // a non-default value
- return new FakeDevice(attr);
+ return absl::make_unique<FakeDevice>(attr);
}
static int64 kStepId = 123;
@@ -211,16 +212,16 @@
void DefineWorker(const ConfigProto& config, const string& worker_name,
const string& device_type, int num_devices) {
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
for (int i = 0; i < num_devices; ++i) {
devices.push_back(NewDevice(
device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i)));
}
- DeviceMgr* dev_mgr = new DeviceMgr(devices);
+ DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name];
- for (auto d : devices) {
+ for (auto d : dev_mgr->ListDevices()) {
dv->push_back(d->name());
}
DeviceResolverDistributed* dev_res =
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc
index ae44b98..842a2b3 100644
--- a/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed_test.cc
@@ -15,6 +15,7 @@
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
#include "tensorflow/core/lib/core/notification.h"
@@ -41,8 +42,8 @@
// Create a fake 'Device' whose only interesting attribute is a non-default
// DeviceLocality.
-static Device* NewDevice(const string& type, const string& name,
- int numa_node) {
+static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
+ int numa_node) {
class FakeDevice : public Device {
public:
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
@@ -53,7 +54,7 @@
attr.set_name(name);
attr.set_device_type(type);
attr.mutable_locality()->set_numa_node(numa_node);
- return new FakeDevice(attr);
+ return absl::make_unique<FakeDevice>(attr);
}
// Create a fake WorkerInterface that responds to requests without RPCs,
@@ -151,19 +152,19 @@
void DefineWorker(const string& worker_name, const string& device_type,
int num_devices) {
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
for (int i = 0; i < num_devices; ++i) {
devices.push_back(NewDevice(
device_type,
strings::StrCat(worker_name, "/device:", device_type, ":", i), i));
}
- DeviceMgr* dev_mgr = new DeviceMgr(devices);
+ DeviceMgr* dev_mgr = new DeviceMgr(std::move(devices));
TestableDeviceResolverDistributed* dev_res =
new TestableDeviceResolverDistributed(dev_mgr, &wc_, worker_name);
resolvers_[worker_name] = dev_res;
device_mgrs_.push_back(dev_mgr);
std::vector<string>* dv = &dev_by_task_[worker_name];
- for (auto d : devices) {
+ for (auto* d : dev_mgr->ListDevices()) {
dv->push_back(d->name());
}
FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res);
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD
index 055e5df..55b2657 100644
--- a/tensorflow/core/distributed_runtime/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/eager/BUILD
@@ -69,6 +69,7 @@
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
+ "@com_google_absl//absl/memory",
],
)
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 5b0a420..13c959d 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -15,6 +15,7 @@
#include "tensorflow/core/distributed_runtime/eager/eager_service_impl.h"
+#include "absl/memory/memory.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@@ -87,7 +88,7 @@
return tensorflow::errors::Internal(
"invalid eager env_ or env_->rendezvous_mgr.");
}
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
// TODO(nareshmodi): Correctly set the SessionOptions.
@@ -97,12 +98,12 @@
request->server_def().task_index()),
&devices));
response->mutable_device_attributes()->Reserve(devices.size());
- for (auto& d : devices) {
+ for (const auto& d : devices) {
*response->add_device_attributes() = d->attributes();
}
- std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
- new tensorflow::DeviceMgr(devices));
+ std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
+ absl::make_unique<DeviceMgr>(std::move(devices));
auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id());
auto session_name = strings::StrCat("eager_", request->rendezvous_id());
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index 5ba522c..7a1463e 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -68,12 +68,9 @@
worker_env_.rendezvous_mgr = &rendezvous_mgr_;
worker_env_.session_mgr = session_mgr_.get();
- Device* device = DeviceFactory::NewDevice(
- "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0");
-
- worker_env_.local_devices = {device};
-
- device_mgr_.reset(new DeviceMgr(worker_env_.local_devices));
+ device_mgr_ = absl::make_unique<DeviceMgr>(DeviceFactory::NewDevice(
+ "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
+ worker_env_.local_devices = device_mgr_->ListDevices();
worker_env_.device_mgr = device_mgr_.get();
}
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 3944668..ee5823e 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -15,6 +15,7 @@
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
+#include <chrono> // NOLINT(build/c++11)
#include <vector>
#include "tensorflow/core/common_runtime/build_graph_options.h"
@@ -25,6 +26,7 @@
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/memory_types.h"
+#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
@@ -386,6 +388,7 @@
MutableRunGraphResponseWrapper* response,
CancellationManager* cancellation_manager,
const NamedTensors& in, StatusCallback done) {
+ const uint64 start_time_usecs = Env::Default()->NowMicros();
// Lookup an item. Holds one ref while executing.
Item* item = nullptr;
{
@@ -443,14 +446,16 @@
return;
}
- StartParallelExecutors(handle, step_id, item, rendezvous, ce_handle,
- collector, cost_graph, cancellation_manager,
- [item, rendezvous, ce_handle, done](const Status& s) {
- done(s);
- rendezvous->Unref();
- item->Unref();
- delete ce_handle;
- });
+ StartParallelExecutors(
+ handle, step_id, item, rendezvous, ce_handle, collector, cost_graph,
+ cancellation_manager,
+ [item, rendezvous, ce_handle, done, start_time_usecs](const Status& s) {
+ done(s);
+ UpdateGraphExecTime(Env::Default()->NowMicros() - start_time_usecs);
+ rendezvous->Unref();
+ item->Unref();
+ delete ce_handle;
+ });
}
void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index d122016..273709a 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -105,6 +105,7 @@
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
],
)
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
index 456c30e..781b7d6 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
@@ -53,30 +53,58 @@
}
return Status::OK();
}
+
} // namespace
-Status NewHostPortGrpcChannel(const string& target,
- SharedGrpcChannelPtr* channel_pointer) {
- // Minimally ensure that the target is valid
- TF_RETURN_IF_ERROR(ValidateHostPortPair(target));
-
+::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options) {
// TODO(mrry): Implement secure channels.
::grpc::ChannelArguments args;
args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
// NOTE(mrry): Some versions of gRPC use a 20-second minimum backoff
// on connection failure, which makes our tests time out.
args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 1000);
+ if (rpc_options != nullptr) {
+ if (rpc_options->compression_algorithm() == "deflate") {
+ args.SetCompressionAlgorithm(GRPC_COMPRESS_DEFLATE);
+ args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
+ rpc_options->compression_level());
+ VLOG(5) << "Setting GRPC compression : algo='"
+ << rpc_options->compression_algorithm()
+ << "' level=" << rpc_options->compression_level();
+ } else if (rpc_options->compression_algorithm() == "gzip") {
+ args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP);
+ args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
+ rpc_options->compression_level());
+ VLOG(5) << "Setting GRPC compression : algo='"
+ << rpc_options->compression_algorithm()
+ << "' level=" << rpc_options->compression_level();
+ } else if (!rpc_options->compression_algorithm().empty()) {
+ LOG(ERROR) << "Invalid compression algorithm: "
+ << rpc_options->compression_algorithm();
+ }
+ }
+ return args;
+}
+
+Status NewHostPortGrpcChannel(const string& target,
+ const RPCOptions* rpc_options,
+ SharedGrpcChannelPtr* channel_pointer) {
+ // Minimally ensure that the target is valid
+ TF_RETURN_IF_ERROR(ValidateHostPortPair(target));
+
+ ::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
*channel_pointer = ::grpc::CreateCustomChannel(
"dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
return Status::OK();
}
ChannelCreationFunction ConvertToChannelCreationFunction(
- const std::function<Status(string, SharedGrpcChannelPtr*)>&
- new_channel_func_ptr) {
+ const std::function<Status(string, const RPCOptions*,
+ SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
SharedGrpcChannelPtr channel_ptr;
- if (new_channel_func_ptr(target, &channel_ptr).ok()) {
+ if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
+ .ok()) {
return channel_ptr;
} else {
return nullptr;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
index 6fa99d7..57d1621 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
@@ -25,6 +25,7 @@
#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
@@ -86,11 +87,14 @@
// Below here are internal-only functions.
+::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options);
+
ChannelCreationFunction ConvertToChannelCreationFunction(
- const std::function<Status(string, SharedGrpcChannelPtr*)>&
- new_channel_func_ptr);
+ const std::function<Status(string, const RPCOptions*,
+ SharedGrpcChannelPtr*)>& new_channel_func_ptr);
Status NewHostPortGrpcChannel(const string& target,
+ const RPCOptions* rpc_options,
SharedGrpcChannelPtr* channel_pointer);
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
index a814ef8..a6fae22 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
@@ -184,18 +184,39 @@
TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) {
SharedGrpcChannelPtr mock_ptr;
- EXPECT_TRUE(NewHostPortGrpcChannel("127.0.0.1:2222", &mock_ptr).ok());
- EXPECT_TRUE(NewHostPortGrpcChannel("example.com:2222", &mock_ptr).ok());
- EXPECT_TRUE(NewHostPortGrpcChannel("fqdn.example.com.:2222", &mock_ptr).ok());
- EXPECT_TRUE(NewHostPortGrpcChannel("[2002:a9c:258e::]:2222", &mock_ptr).ok());
- EXPECT_TRUE(NewHostPortGrpcChannel("[::]:2222", &mock_ptr).ok());
+ EXPECT_TRUE(NewHostPortGrpcChannel("127.0.0.1:2222", /*rpc_options=*/nullptr,
+ &mock_ptr)
+ .ok());
+ EXPECT_TRUE(NewHostPortGrpcChannel("example.com:2222",
+ /*rpc_options=*/nullptr, &mock_ptr)
+ .ok());
+ EXPECT_TRUE(NewHostPortGrpcChannel("fqdn.example.com.:2222",
+ /*rpc_options=*/nullptr, &mock_ptr)
+ .ok());
+ EXPECT_TRUE(NewHostPortGrpcChannel("[2002:a9c:258e::]:2222",
+ /*rpc_options=*/nullptr, &mock_ptr)
+ .ok());
+ EXPECT_TRUE(
+ NewHostPortGrpcChannel("[::]:2222", /*rpc_options=*/nullptr, &mock_ptr)
+ .ok());
- EXPECT_FALSE(NewHostPortGrpcChannel("example.com/abc:2222", &mock_ptr).ok());
- EXPECT_FALSE(NewHostPortGrpcChannel("127.0.0.1:2222/", &mock_ptr).ok());
- EXPECT_FALSE(NewHostPortGrpcChannel("example.com/abc:", &mock_ptr).ok());
- EXPECT_FALSE(NewHostPortGrpcChannel("[::]/:2222", &mock_ptr).ok());
- EXPECT_FALSE(NewHostPortGrpcChannel("[::]:2222/", &mock_ptr).ok());
- EXPECT_FALSE(NewHostPortGrpcChannel("[::]:", &mock_ptr).ok());
+ EXPECT_FALSE(NewHostPortGrpcChannel("example.com/abc:2222",
+ /*rpc_options=*/nullptr, &mock_ptr)
+ .ok());
+ EXPECT_FALSE(NewHostPortGrpcChannel("127.0.0.1:2222/",
+ /*rpc_options=*/nullptr, &mock_ptr)
+ .ok());
+ EXPECT_FALSE(NewHostPortGrpcChannel(
+ "example.com/abc:", /*rpc_options=*/nullptr, &mock_ptr)
+ .ok());
+ EXPECT_FALSE(
+ NewHostPortGrpcChannel("[::]/:2222", /*rpc_options=*/nullptr, &mock_ptr)
+ .ok());
+ EXPECT_FALSE(
+ NewHostPortGrpcChannel("[::]:2222/", /*rpc_options=*/nullptr, &mock_ptr)
+ .ok());
+ EXPECT_FALSE(
+ NewHostPortGrpcChannel("[::]:", /*rpc_options=*/nullptr, &mock_ptr).ok());
}
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index ae722fd..cbd5cd9 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -18,6 +18,7 @@
#include <cstring>
#include <limits>
#include <memory>
+#include <vector>
#include "grpc/support/alloc.h"
#include "grpcpp/grpcpp.h"
@@ -156,10 +157,12 @@
string name_prefix =
strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
"/task:", server_def_.task_index());
- TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
- &master_env_.local_devices));
- worker_env_.local_devices = master_env_.local_devices;
- worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices);
+ std::vector<std::unique_ptr<Device>> devices;
+ TF_RETURN_IF_ERROR(
+ DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
+ worker_env_.device_mgr = new DeviceMgr(std::move(devices));
+ master_env_.local_devices = worker_env_.device_mgr->ListDevices();
+ worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
? new RpcRendezvousMgr(&worker_env_)
: rendezvous_mgr_func(&worker_env_);
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
index fdce1b1..1ad40fe 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
@@ -52,8 +52,9 @@
}
if (!master) {
SharedGrpcChannelPtr master_channel;
- TF_RETURN_IF_ERROR(NewHostPortGrpcChannel(
- options.target.substr(kSchemePrefixLength), &master_channel));
+ TF_RETURN_IF_ERROR(
+ NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
+ &options.config.rpc_options(), &master_channel));
master.reset(NewGrpcMaster(master_channel));
}
session->SetRemoteMaster(std::move(master));
@@ -384,8 +385,9 @@
Status GrpcSession::Reset(const SessionOptions& options,
const std::vector<string>& containers) {
SharedGrpcChannelPtr master_channel;
- TF_RETURN_IF_ERROR(NewHostPortGrpcChannel(
- options.target.substr(kSchemePrefixLength), &master_channel));
+ TF_RETURN_IF_ERROR(
+ NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength),
+ /*rpc_options=*/nullptr, &master_channel));
auto master = NewGrpcMaster(master_channel);
ResetRequest req;
for (const auto& c : containers) req.add_container(c);
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index fc60199..ad0f8e5 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -1066,4 +1066,31 @@
error::INTERNAL == status.code());
}
+TEST(SessionTest, TestCompression) {
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
+ SessionOptions options = Options(cluster->targets()[0], 100);
+ RPCOptions* rpc_options = options.config.mutable_rpc_options();
+ rpc_options->set_compression_algorithm("deflate");
+ rpc_options->set_compression_level(GRPC_COMPRESS_LEVEL_HIGH);
+
+ std::unique_ptr<Session> session(NewRemote(options));
+
+ static const float kTestValue = 409.1934f;
+ Graph graph(OpRegistry::Global());
+ Tensor tensor(DT_FLOAT, TensorShape({1, 1}));
+ tensor.flat<float>()(0) = kTestValue;
+ Node* b = test::graph::Constant(&graph, tensor);
+ GraphDef gdef;
+ graph.ToGraphDef(&gdef);
+ RunOptions run_options;
+ TF_CHECK_OK(session->Create(run_options, gdef));
+
+ std::vector<std::pair<string, Tensor>> inputs;
+ std::vector<Tensor> outputs;
+ TF_CHECK_OK(session->Run(inputs, {b->name()}, {}, &outputs));
+ ASSERT_EQ(1, outputs.size());
+ IsSingleFloatValue(outputs[0], kTestValue);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
index b8cb538..9fb9204 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
@@ -244,6 +244,15 @@
// Record "call" in active_ so that it can be aborted cleanly.
RegisterCall(call);
+ // RendezvousMgr already aborted, shouldn't send RPC call any more
+ if (!call->status().ok()) {
+ call->done()(call->status(), Args(), Args(), Tensor(), false);
+ session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_);
+ call->wi_ = nullptr;
+ get_call_freelist()->Release(call, session()->worker_cache.get());
+ return;
+ }
+
// Start "call".
Ref();
call->Start([this, call]() {
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
index 0323300..1c87fe9 100644
--- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
@@ -42,8 +42,9 @@
WorkerCacheInterface* worker_cache = nullptr;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", NUM_DEVS});
- TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices_));
- device_mgr_.reset(new DeviceMgr(devices_));
+ std::vector<std::unique_ptr<Device>> devices;
+ TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
+ device_mgr_.reset(new DeviceMgr(std::move(devices)));
std::unique_ptr<DeviceResolverDistributed> dr(new DeviceResolverDistributed(
device_mgr_.get(), worker_cache, task_name));
std::unique_ptr<CollectiveParamResolverDistributed> cpr(
@@ -57,7 +58,6 @@
}
std::unique_ptr<RpcCollectiveExecutorMgr> cme_;
- std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
};
diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc
index 38833bd..29fe767 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr.cc
@@ -78,13 +78,13 @@
if (isolate_session_state) {
// Create a private copy of the DeviceMgr for the WorkerSession.
- std::vector<Device*> renamed_devices;
+ std::vector<std::unique_ptr<Device>> renamed_devices;
for (Device* d : worker_env_->local_devices) {
renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
worker_name, d, false, isolate_session_state));
}
- auto device_mgr = MakeUnique<DeviceMgr>(renamed_devices);
+ auto device_mgr = MakeUnique<DeviceMgr>(std::move(renamed_devices));
auto graph_mgr = MakeUnique<GraphMgr>(worker_env_, device_mgr.get());
worker_session.reset(
new WorkerSession(session, worker_name,
diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc
index 9919211..1ab0d20 100644
--- a/tensorflow/core/distributed_runtime/session_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc
@@ -46,11 +46,9 @@
SessionMgrTest()
: mgr_(&env_, "/job:mnist/replica:0/task:0",
std::unique_ptr<WorkerCacheInterface>(), factory_) {
- Device* device =
- FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0")
- .release();
- env_.local_devices = {device};
- device_mgr_.reset(new DeviceMgr(env_.local_devices));
+ device_mgr_ = absl::make_unique<DeviceMgr>(
+ FakeDevice::MakeCPU("/job:mnist/replica:0/task:0/device:fakecpu:0"));
+ env_.local_devices = device_mgr_->ListDevices();
env_.device_mgr = device_mgr_.get();
}
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 838f899..6809c27 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1241,6 +1241,16 @@
}
}
+std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const {
+ std::vector<string> function_names;
+ tf_shared_lock l(mu_);
+ function_names.reserve(function_defs_.size());
+ for (const auto& it : function_defs_) {
+ function_names.emplace_back(it.first);
+ }
+ return function_names;
+}
+
FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
FunctionDefLibrary lib;
tf_shared_lock l(mu_);
@@ -1357,12 +1367,12 @@
if (!grad_func_name.empty()) add_to_func_queue(grad_func_name);
}
- const FunctionDefLibrary library_proto = flib.ToProto();
- for (const auto& it : library_proto.function()) {
- const auto attr_it = it.attr().find(kExperimentalApiImplements);
- if (attr_it != it.attr().end()) {
+ for (const auto& func_name : flib.ListFunctionNames()) {
+ const auto& func_def = flib.Find(func_name);
+ const auto attr_it = func_def->attr().find(kExperimentalApiImplements);
+ if (attr_it != func_def->attr().end()) {
if (reachable_api_interface.contains(attr_it->second.s())) {
- reachable_funcs.insert(it.signature().name());
+ reachable_funcs.insert(func_name);
}
}
}
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 6792cf1..9cf4b0f 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -407,6 +407,9 @@
return function_defs_.size();
}
+ // Returns all the function names in the FunctionLibraryDefinition.
+ std::vector<string> ListFunctionNames() const LOCKS_EXCLUDED(mu_);
+
const OpRegistryInterface* default_registry() const {
return default_registry_;
}
diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc
index f57a79b..75d45fa 100644
--- a/tensorflow/core/framework/function_test.cc
+++ b/tensorflow/core/framework/function_test.cc
@@ -1213,6 +1213,17 @@
EXPECT_EQ(f3->DebugString(), f4->DebugString());
}
+TEST(FunctionLibraryDefinitionTest, FunctionNames) {
+ FunctionDefLibrary proto;
+ *proto.add_function() = test::function::XTimesTwo();
+ *proto.add_function() = test::function::WXPlusB();
+ const FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
+
+ const std::vector<string> function_names = lib_def.ListFunctionNames();
+ const std::vector<string> expected = {"XTimesTwo", "WXPlusB"};
+ EXPECT_EQ(function_names, expected);
+}
+
TEST(FunctionLibraryDefinitionTest, GetAttr_FuncNoAttr) {
FunctionDefLibrary proto;
*proto.add_function() = test::function::XTimesTwo();
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index 508a8d3..9f3204a 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -204,12 +204,19 @@
}
Status ResourceMgr::Cleanup(const string& container) {
+ {
+ tf_shared_lock l(mu_);
+ if (!gtl::FindOrNull(containers_, container)) {
+ // Nothing to cleanup.
+ return Status::OK();
+ }
+ }
Container* b = nullptr;
{
mutex_lock l(mu_);
auto iter = containers_.find(container);
if (iter == containers_.end()) {
- // Nothing to cleanup, it's OK.
+ // Nothing to cleanup, it's OK (concurrent cleanup).
return Status::OK();
}
b = iter->second;
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 6c6d98b..af0b123 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -65,7 +65,7 @@
class NeighborIter; // Declared below
class NodeIter; // Declared below
-class NodeProperties; // Defined in .cc
+struct NodeProperties; // Defined in .cc
class Node {
public:
diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD
index 7b03ec3..7982b35 100644
--- a/tensorflow/core/grappler/BUILD
+++ b/tensorflow/core/grappler/BUILD
@@ -41,6 +41,7 @@
"//tensorflow/core:all_kernels",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 998bd59..c9ce63a 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -832,7 +832,7 @@
EXPECT_FALSE(
GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
- // Check GetTensorShapeProtoFromTensorProto() resturns correct values.
+ // Check GetTensorShapeProtoFromTensorProto() returns correct values.
{
std::vector<int64> shape_expected = {10, 20, 30, 40};
GetTensorProto(DT_INT32, {4}, shape_expected,
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index b9b240e..ae5200b 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -469,8 +469,8 @@
} else {
// Different device, no cached copy; transfer input_node to the
// curr_node's device.
- auto send_and_recv =
- CreateSendRecv(input_node, curr_node, input_node_name);
+ auto send_and_recv = CreateSendRecv(input_node, curr_node, input_node,
+ input_node_name);
// Note that CreateSendRecv() already connected input/output between
// _Send and _Recv ops.
const auto* send = send_and_recv.first;
@@ -608,7 +608,8 @@
}
std::pair<const NodeDef*, const NodeDef*> VirtualScheduler::CreateSendRecv(
- const NodeDef* from, const NodeDef* to, const string& input_name) {
+ const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
+ const string& input_name) {
CHECK(!initialized_) << "CreateSendRecv is called after Init().";
// Connect "from" node to "to" node with _Send and _Recv such that
@@ -639,10 +640,14 @@
send->set_device(ChannelDeviceName(from, to));
auto& send_attr = *(send->mutable_attr());
send_attr[kAttrInputSrc].set_s(input_name);
- // Use input_name as tensor_name.
- send_attr[kAttrTensorName].set_s(input_name);
send_attr[kAttrSrcDevice].set_s(DeviceName(from));
send_attr[kAttrDstDevice].set_s(DeviceName(to));
+ // GraphDef generated by AutoGrappler has tensor_name field when removing
+ // _Send/_Recv nodes.
+ if (input_node->attr().count(kAttrTensorName)) {
+ send_attr[kAttrTensorName].set_s(
+ input_node->attr().at(kAttrTensorName).s());
+ }
// _Recv op.
auto* recv = new NodeDef();
@@ -652,8 +657,10 @@
recv->set_device(DeviceName(to));
auto& recv_attr = *(recv->mutable_attr());
recv_attr[kAttrInputSrc].set_s(input_name);
- // Use input_name as tensor_name.
- recv_attr[kAttrTensorName].set_s(input_name);
+ if (input_node->attr().count(kAttrTensorName)) {
+ recv_attr[kAttrTensorName].set_s(
+ input_node->attr().at(kAttrTensorName).s());
+ }
// NodeState for _Send op.
auto& send_node_state = GetNodeStateOrCreateIt(send);
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index 92e0a88..6a835f3 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -317,7 +317,8 @@
void MaybeUpdateInputOutput(const NodeDef* node);
NodeState& GetNodeStateOrCreateIt(const NodeDef* node);
std::pair<const NodeDef*, const NodeDef*> CreateSendRecv(
- const NodeDef* from, const NodeDef* to, const string& input_name);
+ const NodeDef* from, const NodeDef* to, const NodeDef* input_node,
+ const string& input_name);
string DeviceName(const NodeDef* node) const;
string SanitizedDeviceName(const NodeDef* node) const;
string ChannelDeviceName(const NodeDef* from, const NodeDef* to) const;
diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node.h b/tensorflow/core/grappler/graph_analyzer/sig_node.h
index 45c0ed3..66d290d 100644
--- a/tensorflow/core/grappler/graph_analyzer/sig_node.h
+++ b/tensorflow/core/grappler/graph_analyzer/sig_node.h
@@ -178,7 +178,7 @@
// computed.
size_t GetTopoHash(int distance) const;
- // The the hash value for the highest computed distance. It must be previously
+ // The hash value for the highest computed distance. It must be previously
// computed.
size_t GetHighTopoHash() const {
CHECK(!topo_hash_.empty());
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index cf99f49..9224ee7 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -102,10 +102,11 @@
}
// Instantiate all variables for function library runtime creation.
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices));
- std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(devices));
+ Device* cpu_device = devices[0].get();
+ std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(std::move(devices)));
FunctionLibraryDefinition function_library(OpRegistry::Global(),
graph_def.library());
Env* env = Env::Default();
@@ -124,7 +125,7 @@
new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env,
graph_def.versions().producer(),
&function_library, *optimizer_opts));
- FunctionLibraryRuntime* flr = pflr->GetFLR(devices[0]->name());
+ FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());
// Create the GraphOptimizer to optimize the graph def.
GraphConstructorOptions graph_ctor_opts;
@@ -137,7 +138,7 @@
// Optimize the graph.
::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
- optimizer.Optimize(flr, env, devices[0], &graphptr, /*shape_map=*/nullptr);
+ optimizer.Optimize(flr, env, cpu_device, &graphptr, /*shape_map=*/nullptr);
graphptr->ToGraphDef(output_graph_def);
// The default values of attributes might have been stripped by the optimizer.
@@ -519,7 +520,7 @@
}
if (!iter->second.has_tensor() ||
iter->second.tensor().string_val_size() != 1) {
- LOG(INFO) << "Unexected AttrValue proto: "
+ LOG(INFO) << "Unexpected AttrValue proto: "
<< iter->second.DebugString();
return nullptr;
}
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 0624839..38fc1ff 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -551,14 +551,15 @@
return false;
}
-bool IsFreeOfSideEffect(const NodeDef& node) {
+bool IsFreeOfSideEffect(const NodeDef& node,
+ const OpRegistryInterface* op_registry) {
// Placeholders must be preserved to keep the graph feedable.
if (IsPlaceholder(node)) {
return false;
}
const OpDef* op_def = nullptr;
const string& op_name = node.op();
- Status status = OpRegistry::Global()->LookUpOpDef(op_name, &op_def);
+ Status status = op_registry->LookUpOpDef(op_name, &op_def);
if (!status.ok()) {
return false;
}
@@ -582,6 +583,10 @@
return !ModifiesInputsInPlace(node);
}
+bool IsFreeOfSideEffect(const NodeDef& node) {
+ return IsFreeOfSideEffect(node, OpRegistry::Global());
+}
+
bool ModifiesInputsInPlace(const NodeDef& node) {
// Some nodes do in-place updates on regular tensor inputs.
string op_name = node.op();
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index bd286f2..67897e8 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -17,6 +17,7 @@
#define TENSORFLOW_CORE_GRAPPLER_OP_TYPES_H_
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@@ -180,7 +181,9 @@
// value.
bool IsPersistent(const NodeDef& node);
-bool IsFreeOfSideEffect(const NodeDef& node);
+bool IsFreeOfSideEffect(const NodeDef& node,
+ const OpRegistryInterface* op_registry);
+bool IsFreeOfSideEffect(const NodeDef& node); // use OpRegistry::Global()
// Returns true if the takes a tensor reference as input, or if looking up its
// OpDef failed.
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index b6f989f..8e66295 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -142,7 +142,6 @@
":graph_optimizer",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
- "//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
@@ -150,6 +149,7 @@
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/utils:functions",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index cf294cd..e676323 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2309,7 +2309,9 @@
~SimplifyAggregation() override = default;
bool IsSupported(const NodeDef* node) const override {
- return IsAggregate(*node) && NumNonControlInputs(*node) > 0;
+ return IsAggregate(*node) && NumNonControlInputs(*node) > 0 &&
+ GetDataTypeFromAttr(*node, "T") !=
+ DT_VARIANT; // TODO(b/119787146): Enable for variants.
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
@@ -2407,7 +2409,7 @@
ctx().graph_properties->GetInputProperties(node->name())[1];
for (int i = 0; i < pow_props.shape().dim_size(); ++i) {
if (pow_props.shape().dim(i).size() < 0) {
- // skip if p is is not fully defined.
+ // skip if p is not fully defined.
return Status::OK();
}
}
@@ -2459,7 +2461,7 @@
ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
for (int i = 0; i < value_props.shape().dim_size(); ++i) {
if (value_props.shape().dim(i).size() < 0) {
- // skip if b is is not fully defined.
+ // skip if b is not fully defined.
return Status::OK();
}
}
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index b6286c4..35d2289 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -14,6 +14,7 @@
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
+#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -3793,5 +3794,31 @@
tensors[fCSlice2ToOut]);
}
+TEST_F(ArithmeticOptimizerTest, SimplifyAggregationBFloat16) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output cast = ops::Cast(s.WithOpName("cast"), x, DT_BFLOAT16);
+ Output add = ops::AddN(s.WithOpName("add"), {cast, cast});
+ Output id = ops::Identity(s.WithOpName("id"), add);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"id"};
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlySimplifyAggregation(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // Extra node created for multiplier.
+ EXPECT_EQ(5, output.node_size());
+
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<bfloat16>(tensors_expected[0], tensors[0]);
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index f99826d..22013ea 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -16,7 +16,9 @@
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include <unordered_map>
+#include <vector>
+#include "absl/memory/memory.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@@ -218,13 +220,15 @@
const GrapplerItem& item)
: grappler_item_id_(item.id),
graph_version_(item.graph.versions().producer()),
+ opt_level_(opt_level),
function_library_(OpRegistry::Global(), item.graph.library()),
graph_view_(&item.graph) {
InitializeTrulyConstNodes(item);
- InitializeInlinedFunctions(opt_level, item);
InitializeFetchNodes(item);
}
+ const RewriterConfig::Toggle opt_level() const { return opt_level_; }
+
const FunctionLibraryDefinition& function_library() const {
return function_library_;
}
@@ -253,10 +257,6 @@
return fetch_nodes_.find(node_name) != fetch_nodes_.end();
}
- bool IsInlinedFunction(const string& name) const {
- return inlined_functions_.count(name) > 0;
- }
-
bool IsTrulyConst(const string& name) const {
return TrulyConstNode(name) != nullptr;
}
@@ -265,11 +265,6 @@
return gtl::FindWithDefault(truly_const_nodes_, name, nullptr);
}
- // Find inlining candidate by name. Return nullptr if not found.
- const FunctionDef* FindInlinedFunction(const string& name) const {
- return gtl::FindWithDefault(inlined_functions_, name, nullptr);
- }
-
const FunctionSpecialization* FindFunctionSpecialization(
const FunctionSpecializationSignature& sig) const {
return gtl::FindOrNull(specialized_functions_, sig);
@@ -310,26 +305,6 @@
}
}
- void InitializeInlinedFunctions(RewriterConfig::Toggle opt_level,
- const GrapplerItem& item) {
- bool aggressive = opt_level == RewriterConfig::AGGRESSIVE;
-
- for (const FunctionDef& func : item.graph.library().function()) {
- // Can't create IdentityN nodes with no input or output: skip these
- // functions for now.
- if (func.signature().input_arg_size() == 0 ||
- func.signature().output_arg_size() == 0) {
- continue;
- }
- bool marked_noinline = MarkedNoInline(func);
- bool marked_specialized = MarkedSpecialized(func);
-
- if (!marked_specialized && (!marked_noinline || aggressive)) {
- inlined_functions_[func.signature().name()] = &func;
- }
- }
- }
-
void InitializeFetchNodes(const GrapplerItem& item) {
for (const string& fetch : item.fetch) {
fetch_tensors_.insert(fetch);
@@ -343,19 +318,21 @@
DeviceAttributes attr;
attr.set_name("/device:CPU:0");
attr.set_device_type("CPU");
- Device* device = new FakeCPUDevice(env, attr);
- device_mgr_.reset(new DeviceMgr({device}));
+ std::vector<std::unique_ptr<Device>> devices;
+ devices.push_back(absl::make_unique<FakeCPUDevice>(env, attr));
+ device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
OptimizerOptions optimizer_opts;
optimizer_opts.set_do_function_inlining(true);
process_flr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), env, graph_version_, &function_library_,
optimizer_opts));
- flr_ = process_flr_->GetFLR(device->name());
+ flr_ = process_flr_->GetFLR(device_mgr_->ListDevices()[0]->name());
}
}
const string grappler_item_id_;
const int graph_version_;
+ const RewriterConfig::Toggle opt_level_;
FunctionLibraryDefinition function_library_;
// These fields initialized lazily only if needed.
@@ -363,8 +340,6 @@
std::unique_ptr<ProcessFunctionLibraryRuntime> process_flr_;
FunctionLibraryRuntime* flr_ = nullptr;
- // Functions that can be inlined into optimized graph.
- std::unordered_map<string, const FunctionDef*> inlined_functions_;
// Nodes that are Const and not in feed.
std::unordered_map<string, const NodeDef*> truly_const_nodes_;
// Specialized functions.
@@ -387,6 +362,65 @@
TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
};
+// Returns a pointer to the called function definition iff the given node is
+// indeed a function call. Otherwise returns nullptr.
+const FunctionDef* FindFunctionCall(const FunctionOptimizerContext& ctx,
+ const NodeDef& node) {
+ // Check if a node does indirect function call via PartitionedCallOp.
+ if (IsPartitionedCall(node) || IsStatefulPartitionedCall(node)) {
+ const AttrValue* func_attr = AttrSlice(node).Find("f");
+ return (func_attr != nullptr && func_attr->has_func())
+ ? ctx.function_library().Find(func_attr->func().name())
+ : nullptr;
+ }
+
+ // Check if the function op itself is a function name.
+ return ctx.function_library().Find(node.op());
+}
+
+// Returns true iff `node` is a direct function call of `func`, and we know how
+// to inline it into the main graph.
+bool IsInlinableDirectFunctionCall(const FunctionOptimizerContext& ctx,
+ const FunctionDef& func,
+ const NodeDef& node) {
+ // Indirect function calls (PartitionedCallOp) have automatic control
+ // dependencies and inlined separately from direct function calls.
+ bool is_direct_function_call = IsDirectFunctionCall(func, node);
+
+ // For direct function calls we insert IdentityN nodes before/after inlined
+ // function body to preserve function call semantics (all inputs evaluated
+ // before function evaluation starts, and all function body nodes finished
+ // before output consumed by other nodes).
+ bool has_inputs = func.signature().input_arg_size() > 0;
+ // TODO(ezhulenev): Relax constraint on output args?
+ bool has_outputs = func.signature().output_arg_size() > 0;
+
+ // Function must execute all the nodes in a function body that might have side
+ // effects. After inlining these nodes into the main graph, we can no longer
+ // guarantee that. For now we disable inlining functions with side effects.
+ //
+ // Attaching control dependency to the output IdentityN node is not safe,
+ // because it might be split or pruned in a later optimization pass.
+ //
+ // Indirect function calls (via PartitionedCallOp) have automatic dependency
+ // tracking, and allow us to safely inline functions with side effects.
+ bool free_of_side_effects =
+ std::all_of(func.node_def().begin(), func.node_def().end(),
+ [&ctx](const NodeDef& node) {
+ return IsFreeOfSideEffect(node, &ctx.function_library());
+ });
+
+ bool marked_noinline = MarkedNoInline(func);
+ bool marked_specialized = MarkedSpecialized(func);
+
+ // We ignore `_noinline` marker in aggressive mode.
+ bool aggressive = ctx.opt_level() == RewriterConfig::AGGRESSIVE;
+
+ return is_direct_function_call && has_inputs && has_outputs &&
+ free_of_side_effects && !marked_specialized &&
+ (!marked_noinline || aggressive);
+}
+
gtl::FlatSet<int> GetActiveOutputs(const NodeDef& node,
const FunctionOptimizerContext& ctx,
int size_hint = 0) {
@@ -605,6 +639,9 @@
// 2. Remove inputs corresponding to the pushed down consts.
RemovePushedDownConstInputs(specialization, specialized_func_node);
+ // NOTE: PartitionedCallOp has `Tin` and `Tout` attributes for input/output
+ // types, that must be in sync with updated function signature.
+
// 3. Update input types for the indirect function calls.
if (is_indirect_call) {
RemovePushedDownConstInputTypes(specialization, func_node,
@@ -802,13 +839,16 @@
return outputs;
}
-Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
- const FunctionOptimizerContext& ctx,
- const int graph_def_version, GraphDef* optimized_graph) {
- VLOG(2) << "Inline function instantiation: " << SummarizeNodeDef(func_node);
+Status InlineDirectFunctionCall(const NodeDef& func_node,
+ const FunctionDef& func,
+ const FunctionOptimizerContext& ctx,
+ const int graph_def_version,
+ GraphDef* optimized_graph) {
+ VLOG(2) << "Inline direct function call: " << SummarizeNodeDef(func_node);
- // Specialized function call kernels might have behavior that is not
- // representable in a graph (e.g. runtime ops device placing).
+ // Indirect function calls (via PartitionedCallOp) have automatic control
+ // dependencies, and doesn't need IdentityN nodes before/after inlined
+ // function body, and we inline them separately.
if (!IsDirectFunctionCall(func, func_node)) {
return errors::InvalidArgument("Can't inline indirect function call");
}
@@ -874,14 +914,16 @@
// Make sure the node is placed.
func_body_node.set_device(func_node.device());
- // Check if a body node is itself a function.
+ // Check if a body node is itself a function call and can be inlined.
const FunctionDef* func_body_node_func =
- ctx.FindInlinedFunction(func_body_node.op());
- if (func_body_node_func != nullptr) {
+ FindFunctionCall(ctx, func_body_node);
+ if (func_body_node_func != nullptr &&
+ IsInlinableDirectFunctionCall(ctx, *func_body_node_func,
+ func_body_node)) {
// Recursively inline function calls.
- TF_RETURN_IF_ERROR(InlineFunction(func_body_node, *func_body_node_func,
- ctx, graph_def_version,
- optimized_graph));
+ TF_RETURN_IF_ERROR(
+ InlineDirectFunctionCall(func_body_node, *func_body_node_func, ctx,
+ graph_def_version, optimized_graph));
} else {
// Annotate the node with the function attributes.
for (const auto& attr : func.attr()) {
@@ -1012,8 +1054,6 @@
bool specialize_func = options_.enable_function_specialization;
for (const NodeDef& node : item.graph.node()) {
- const string op_name = node.op();
-
// Each node optimization can modify optimized graph only by adding new
// nodes, we can check node size to make sure that graph was not modified.
const int num_nodes_before = optimized_graph->node_size();
@@ -1042,11 +1082,13 @@
// 1. Inline symbolic gradients into the optimized graph. //
// ---------------------------------------------------------------------- //
- if (op_name == "SymbolicGradient" && inline_gradients) {
- // Inline symbolic gradients only if the corresponding function is inlined
+ if (IsSymbolicGradient(node) && inline_gradients) {
+ // Inline symbolic gradients only if the corresponding function is not
+ // marked as `_noinline`.
const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
- string f_name = f_attr != nullptr ? f_attr->func().name() : "";
- if (ctx.IsInlinedFunction(f_name)) {
+ const string f_name = f_attr != nullptr ? f_attr->func().name() : "";
+ const FunctionDef* func = ctx.function_library().Find(f_name);
+ if (func && !MarkedNoInline(*func)) {
TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
InlineSymbolicGradient(node, &ctx, optimized_graph));
continue;
@@ -1054,28 +1096,33 @@
}
// ---------------------------------------------------------------------- //
- // 2. Inline or specialize direct function calls. //
+ // 2. Inline or specialize function calls. //
// ---------------------------------------------------------------------- //
- const FunctionDef* func = ctx.function_library().Find(op_name);
+ // Find if a node is a function call (direct or indirect).
+ const FunctionDef* func = FindFunctionCall(ctx, node);
+
if (func != nullptr) {
- // 2a. Inline it if it's allowed to do so.
- if (inline_func && ctx.IsInlinedFunction(op_name)) {
+ const string& func_name = func->signature().name();
+
+ // 2a. Inline direct function call if it's inlinable.
+ if (inline_func && IsInlinableDirectFunctionCall(ctx, *func, node)) {
// Inline function body into the optimized graph}
- TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
- InlineFunction(node, *func, ctx, item.graph.versions().producer(),
- optimized_graph));
+ TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(InlineDirectFunctionCall(
+ node, *func, ctx, item.graph.versions().producer(),
+ optimized_graph));
continue;
}
- // Do not specialize if function has custom gradient.
- const string grad_func = ctx.function_library().FindGradient(op_name);
-
// 2b. Specialize it to it's instantiation context if can't be inlined,
// and it has something worth specializing.
bool specialization_worthy = IsParametrized(*func) ||
HasTrulyConstInputs(node, ctx) ||
HasUnusedOutputs(node, *func, ctx);
+
+ // Do not specialize if function has custom gradient.
+ const string grad_func = ctx.function_library().FindGradient(func_name);
+
if (specialize_func && grad_func.empty() && specialization_worthy) {
// TODO(ezhulenev): Specialize function call if input has a known shape.
// Specialize function body for its instantiation attributes and inputs.
@@ -1087,41 +1134,6 @@
}
// ---------------------------------------------------------------------- //
- // 3. Specialize indirect function calls through the PartitionedCallOp. //
- // ---------------------------------------------------------------------- //
-
- bool is_partitioned_call =
- IsPartitionedCall(node) || IsStatefulPartitionedCall(node);
-
- // We can only specialize PartitionedCall ops. Inlining is not supported.
- if (is_partitioned_call && specialize_func) {
- const AttrValue* func_attr = AttrSlice(node).Find("f");
- string indirect_func_name =
- (func_attr != nullptr && func_attr->has_func())
- ? func_attr->func().name()
- : "";
- const FunctionDef* indirect_func =
- ctx.function_library().Find(indirect_func_name);
-
- if (indirect_func != nullptr) {
- // Do not specialize if function has custom gradient.
- const string grad_func =
- ctx.function_library().FindGradient(indirect_func_name);
-
- // Specialize it to it's instantiation context.
- bool specialization_worthy =
- IsParametrized(*indirect_func) || HasTrulyConstInputs(node, ctx) ||
- HasUnusedOutputs(node, *indirect_func, ctx);
- if (grad_func.empty() && specialization_worthy) {
- TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(SpecializeFunction(
- node, *indirect_func, item.graph.versions().producer(), &ctx,
- optimized_graph));
- continue;
- }
- }
- }
-
- // ---------------------------------------------------------------------- //
// If we reached this point, node was not handled by any of the stages
// (inline, specialize), simply add a copy to the graph.
add_node_copy();
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 7dc62e2..f465350 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -119,6 +119,8 @@
"Exit",
"Exp",
"Expm1",
+ "FakeQuantWithMinMaxVars",
+ "FakeQuantWithMinMaxArgs",
"Fill",
"Floor",
"FloorDiv",
@@ -161,6 +163,8 @@
"PreventGradient",
"Prod",
"Polygamma",
+ "QuantizeAndDequantizeV2",
+ "QuantizeAndDequantizeV3",
"Pow",
"Real",
"RealDiv",
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 9336c4d..2977544 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -40,8 +40,8 @@
template <typename T>
bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
using RealType = typename Eigen::NumTraits<T>::Real;
- if (value > static_cast<double>(std::numeric_limits<RealType>::max()) ||
- value < static_cast<double>(std::numeric_limits<RealType>::min())) {
+ if (value > static_cast<double>(Eigen::NumTraits<RealType>::highest()) ||
+ value < static_cast<double>(Eigen::NumTraits<RealType>::lowest())) {
return false;
}
tensor->flat<T>()(0) = static_cast<T>(value);
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index 8cbff1c..e993391 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -16,10 +16,13 @@
#include "tensorflow/core/grappler/utils.h"
#include <unistd.h>
+#include <limits>
#include <memory>
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -441,6 +444,26 @@
BM_ParseNodeNameAsStringPiece("foo/bar/baz:123", foo_bar_baz_123);
BM_ParseNodeNameAsStringPiece("^foo/bar/baz:123", foo_bar_baz_123_ctrl);
+TEST_F(UtilsTest, SetTensorValueBFloat16) {
+ Tensor t(DT_BFLOAT16, TensorShape({}));
+ TF_ASSERT_OK(SetTensorValue(t.dtype(), 2, &t));
+ test::ExpectTensorEqual<bfloat16>(Tensor(bfloat16(2)), t);
+}
+
+TEST_F(UtilsTest, SetTensorValueBFloat16IntMax) {
+ Tensor t(DT_BFLOAT16, TensorShape({}));
+ TF_ASSERT_OK(SetTensorValue(t.dtype(), std::numeric_limits<int>::max(), &t));
+ test::ExpectTensorEqual<bfloat16>(
+ Tensor(bfloat16(std::numeric_limits<int>::max())), t);
+}
+
+TEST_F(UtilsTest, SetTensorValueBFloat16IntMin) {
+ Tensor t(DT_BFLOAT16, TensorShape({}));
+ TF_ASSERT_OK(SetTensorValue(t.dtype(), std::numeric_limits<int>::min(), &t));
+ test::ExpectTensorEqual<bfloat16>(
+ Tensor(bfloat16(std::numeric_limits<int>::min())), t);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ae76034..1efce93 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -199,8 +199,16 @@
name = "conv_2d",
hdrs = ["conv_2d.h"],
gpu_srcs = [
- "conv_2d_gpu.cu.cc",
"conv_2d.h",
+ "conv_2d_gpu.h",
+ "conv_2d_gpu_double.cu.cc",
+ "conv_2d_gpu_float.cu.cc",
+ "conv_2d_gpu_half.cu.cc",
+ "conv_2d_gpu_int.cu.cc",
+ "conv_2d_gpu_uint16.cu.cc",
+ "conv_2d_gpu_uint32.cu.cc",
+ "conv_2d_gpu_uint64.cu.cc",
+ "conv_2d_gpu_uint8.cu.cc",
],
deps = [
":eigen_helpers",
diff --git a/tensorflow/core/kernels/adjust_contrast_op.cc b/tensorflow/core/kernels/adjust_contrast_op.cc
index 72155fd..47e10f5 100644
--- a/tensorflow/core/kernels/adjust_contrast_op.cc
+++ b/tensorflow/core/kernels/adjust_contrast_op.cc
@@ -320,13 +320,14 @@
int64 batch = outputs.dimension(0);
int64 image_size = outputs.dimension(1);
int64 channels = outputs.dimension(2);
- // Similar to the reduction case, a straighforward implementation of this
+ // Similar to the reduction case, a straightforward implementation of this
// does not utilize vectorization well because of the small channel size.
// This algorithm repeatedly increases the area to be copied, and leads to
// much better vectorinizations in the copy.
for (int64 i = 0; i < batch; i++) {
// Copy over the inputs into outputs in this batch. Effectively:
- // outputs(i, :, k) = inputs(i, k). An example of how this algorith works:
+ // outputs(i, :, k) = inputs(i, k). An example of how this algorithm
+ // works:
//
// x = float[1, 3], y = float[2048, 3]
// round 0
diff --git a/tensorflow/core/kernels/adjust_hue_op.cc b/tensorflow/core/kernels/adjust_hue_op.cc
index 6079aa7..52dec94 100644
--- a/tensorflow/core/kernels/adjust_hue_op.cc
+++ b/tensorflow/core/kernels/adjust_hue_op.cc
@@ -216,8 +216,8 @@
*context->device()->tensorflow_cpu_worker_threads();
Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
kCostPerChannel,
- [channel_count, &input_data, &output_data, delta_h](
- int64 start_channel, int64 end_channel) {
+ [&input_data, &output_data, delta_h](int64 start_channel,
+ int64 end_channel) {
const float* p = input_data.data() + start_channel * kChannelSize;
float* q = output_data.data() + start_channel * kChannelSize;
for (int i = start_channel; i < end_channel; i++) {
diff --git a/tensorflow/core/kernels/barrier_ops.cc b/tensorflow/core/kernels/barrier_ops.cc
index 944564d..aa91235 100644
--- a/tensorflow/core/kernels/barrier_ops.cc
+++ b/tensorflow/core/kernels/barrier_ops.cc
@@ -180,7 +180,7 @@
// SQSS is closed, nothing is left in the incomplete set,
// the queue is not already marked as closed, and (most
// importantly), the queue has entries in it.
- [this, ctx, callback, component_index]() {
+ [this, ctx, callback]() {
if (!ctx->status().ok()) {
callback();
return;
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc
index 382c9d5..1587eb5 100644
--- a/tensorflow/core/kernels/control_flow_ops.cc
+++ b/tensorflow/core/kernels/control_flow_ops.cc
@@ -71,11 +71,13 @@
TF_CALL_ALL_TYPES(REGISTER_CPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_CPU_REF_SWITCH);
+REGISTER_CPU_SWITCH(uint64);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_SWITCH);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_REF_SWITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_SWITCH);
+REGISTER_GPU_SWITCH(uint64);
#undef REGISTER_CPU_SWITCH
#undef REGISTER_CPU_REF_SWITCH
@@ -263,6 +265,7 @@
TF_CALL_QUANTIZED_TYPES(REGISTER_GPU_REF_KERNEL);
REGISTER_GPU_KERNEL(bool);
REGISTER_GPU_REF_KERNEL(bool);
+REGISTER_GPU_KERNEL(uint64);
#undef REGISTER_GPU_KERNEL
#undef REGISTER_GPU_REF_KERNEL
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index a6964b1..1bac2a1 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -162,7 +162,7 @@
merged_dims[1] = in.dimension(NDIMS - 2); // input filters
merged_dims[2] = in.dimension(NDIMS - 1); // output filters
- CHECK(dst_filter_format == FORMAT_OIHW)
+ DCHECK(dst_filter_format == FORMAT_OIHW)
<< "Unsupported destination filter format: "
<< ToString(dst_filter_format);
// Source filter format is FORMAT_HWIO and spatial dimensions HW are merged
diff --git a/tensorflow/core/kernels/conv_2d_gpu.cu.cc b/tensorflow/core/kernels/conv_2d_gpu.h
similarity index 90%
rename from tensorflow/core/kernels/conv_2d_gpu.cu.cc
rename to tensorflow/core/kernels/conv_2d_gpu.h
index c6adf9e..8d11757 100644
--- a/tensorflow/core/kernels/conv_2d_gpu.cu.cc
+++ b/tensorflow/core/kernels/conv_2d_gpu.h
@@ -13,6 +13,9 @@
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_
+#define TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_
+
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
@@ -34,7 +37,7 @@
typedef Eigen::GpuDevice GPUDevice;
namespace functor {
-namespace {
+
template <typename T, bool conjugate>
struct maybe_conj {
__device__ static __inline__ T run(T x) {
@@ -75,8 +78,6 @@
}
};
-} // namespace
-
// TODO(mjanusz): Move this to a shared util file.
// A simple array that contains data that can be passed between CPU and GPU.
template <typename T, int IndexCount, T DefaultValue>
@@ -433,7 +434,7 @@
combined_dims[2] = in.dimension(NDIMS - 1); // output filters
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
- DCHECK(dst_filter_format == FORMAT_OIHW)
+ CHECK(dst_filter_format == FORMAT_OIHW)
<< "Unsupported output layout: " << ToString(dst_filter_format);
ShuffleInTensor3Simple<T, 2, 1, 0>
@@ -998,82 +999,9 @@
}
};
-template struct ShuffleAndReverse<Eigen::GpuDevice, float, 4, int>;
-template struct ShuffleAndReverse<Eigen::GpuDevice, Eigen::half, 4, int>;
-
-template struct ShuffleAndReverse<Eigen::GpuDevice, float, 4,
- Eigen::DenseIndex>;
-template struct ShuffleAndReverse<Eigen::GpuDevice, Eigen::half, 4,
- Eigen::DenseIndex>;
-
-template struct TransformDepth<Eigen::GpuDevice, float, int>;
-template struct TransformDepth<Eigen::GpuDevice, Eigen::half, int>;
-
-template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, uint8>;
-template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, uint16>;
-template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, uint32>;
-template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, uint64>;
-template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float4>;
-template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float2,
- /*conjugate=*/true>;
-template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, double2,
- /*conjugate=*/true>;
-template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, Eigen::half>;
-
-template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, uint8>;
-template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, uint16>;
-template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, uint32>;
-template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, uint64>;
-template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, float4>;
-template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, float2,
- /*conjugate=*/true>;
-template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, double2,
- /*conjugate=*/true>;
-
-// For 2d ops.
-template struct TransformFilter<Eigen::GpuDevice, double, int, 4>;
-template struct TransformFilter<Eigen::GpuDevice, float, int, 4>;
-template struct TransformFilter<Eigen::GpuDevice, Eigen::half, int, 4>;
-
-template struct ReverseTransformFilter<Eigen::GpuDevice, double, 4>;
-template struct ReverseTransformFilter<Eigen::GpuDevice, float, 4>;
-template struct ReverseTransformFilter<Eigen::GpuDevice, Eigen::half, 4>;
-
-template struct NHWCToNCHW<Eigen::GpuDevice, double, 4>;
-template struct NHWCToNCHW<Eigen::GpuDevice, float, 4>;
-template struct NHWCToNCHW<Eigen::GpuDevice, Eigen::half, 4>;
-
-template struct NCHWToNHWC<Eigen::GpuDevice, double, 4>;
-template struct NCHWToNHWC<Eigen::GpuDevice, float, 4>;
-template struct NCHWToNHWC<Eigen::GpuDevice, Eigen::half, 4>;
-
-template struct PadInput<Eigen::GpuDevice, int, int, 4>;
-template struct PadInput<Eigen::GpuDevice, double, int, 4>;
-template struct PadInput<Eigen::GpuDevice, float, int, 4>;
-template struct PadInput<Eigen::GpuDevice, Eigen::half, int, 4>;
-
-// For 3d ops.
-template struct TransformFilter<Eigen::GpuDevice, double, int, 5>;
-template struct TransformFilter<Eigen::GpuDevice, float, int, 5>;
-template struct TransformFilter<Eigen::GpuDevice, Eigen::half, int, 5>;
-
-template struct ReverseTransformFilter<Eigen::GpuDevice, double, 5>;
-template struct ReverseTransformFilter<Eigen::GpuDevice, float, 5>;
-template struct ReverseTransformFilter<Eigen::GpuDevice, Eigen::half, 5>;
-
-template struct NHWCToNCHW<Eigen::GpuDevice, double, 5>;
-template struct NHWCToNCHW<Eigen::GpuDevice, float, 5>;
-template struct NHWCToNCHW<Eigen::GpuDevice, Eigen::half, 5>;
-
-template struct NCHWToNHWC<Eigen::GpuDevice, double, 5>;
-template struct NCHWToNHWC<Eigen::GpuDevice, float, 5>;
-template struct NCHWToNHWC<Eigen::GpuDevice, Eigen::half, 5>;
-
-template struct PadInput<Eigen::GpuDevice, double, int, 5>;
-template struct PadInput<Eigen::GpuDevice, float, int, 5>;
-template struct PadInput<Eigen::GpuDevice, Eigen::half, int, 5>;
-
} // namespace functor
} // namespace tensorflow
#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_
diff --git a/tensorflow/core/kernels/conv_2d_gpu_double.cu.cc b/tensorflow/core/kernels/conv_2d_gpu_double.cu.cc
new file mode 100644
index 0000000..353d6d1
--- /dev/null
+++ b/tensorflow/core/kernels/conv_2d_gpu_double.cu.cc
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_2d_gpu.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, double2,
+ /*conjugate=*/true>;
+
+template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, double2,
+ /*conjugate=*/true>;
+
+// For 2d ops.
+template struct TransformFilter<Eigen::GpuDevice, double, int, 4>;
+template struct ReverseTransformFilter<Eigen::GpuDevice, double, 4>;
+template struct NHWCToNCHW<Eigen::GpuDevice, double, 4>;
+template struct NCHWToNHWC<Eigen::GpuDevice, double, 4>;
+template struct PadInput<Eigen::GpuDevice, double, int, 4>;
+
+// For 3d ops.
+template struct TransformFilter<Eigen::GpuDevice, double, int, 5>;
+template struct ReverseTransformFilter<Eigen::GpuDevice, double, 5>;
+template struct NHWCToNCHW<Eigen::GpuDevice, double, 5>;
+template struct NCHWToNHWC<Eigen::GpuDevice, double, 5>;
+template struct PadInput<Eigen::GpuDevice, double, int, 5>;
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_2d_gpu_float.cu.cc b/tensorflow/core/kernels/conv_2d_gpu_float.cu.cc
new file mode 100644
index 0000000..21030dd
--- /dev/null
+++ b/tensorflow/core/kernels/conv_2d_gpu_float.cu.cc
@@ -0,0 +1,63 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <algorithm>
+#include <array>
+#include <limits>
+#include <utility>
+
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_2d_gpu.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template struct ShuffleAndReverse<Eigen::GpuDevice, float, 4, int>;
+template struct ShuffleAndReverse<Eigen::GpuDevice, float, 4,
+ Eigen::DenseIndex>;
+
+template struct TransformDepth<Eigen::GpuDevice, float, int>;
+
+template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float4>;
+template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, float2,
+ /*conjugate=*/true>;
+
+template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, float4>;
+template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, float2,
+ /*conjugate=*/true>;
+
+// For 2d ops.
+template struct TransformFilter<Eigen::GpuDevice, float, int, 4>;
+template struct ReverseTransformFilter<Eigen::GpuDevice, float, 4>;
+template struct NHWCToNCHW<Eigen::GpuDevice, float, 4>;
+template struct NCHWToNHWC<Eigen::GpuDevice, float, 4>;
+template struct PadInput<Eigen::GpuDevice, float, int, 4>;
+
+// For 3d ops.
+template struct TransformFilter<Eigen::GpuDevice, float, int, 5>;
+template struct ReverseTransformFilter<Eigen::GpuDevice, float, 5>;
+template struct NHWCToNCHW<Eigen::GpuDevice, float, 5>;
+template struct NCHWToNHWC<Eigen::GpuDevice, float, 5>;
+template struct PadInput<Eigen::GpuDevice, float, int, 5>;
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_2d_gpu_half.cu.cc b/tensorflow/core/kernels/conv_2d_gpu_half.cu.cc
new file mode 100644
index 0000000..9483086
--- /dev/null
+++ b/tensorflow/core/kernels/conv_2d_gpu_half.cu.cc
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <algorithm>
+#include <array>
+#include <limits>
+#include <utility>
+
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_2d_gpu.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template struct ShuffleAndReverse<Eigen::GpuDevice, Eigen::half, 4, int>;
+template struct ShuffleAndReverse<Eigen::GpuDevice, Eigen::half, 4,
+ Eigen::DenseIndex>;
+
+template struct TransformDepth<Eigen::GpuDevice, Eigen::half, int>;
+
+template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, Eigen::half>;
+
+// For 2d ops.
+template struct TransformFilter<Eigen::GpuDevice, Eigen::half, int, 4>;
+template struct ReverseTransformFilter<Eigen::GpuDevice, Eigen::half, 4>;
+template struct NHWCToNCHW<Eigen::GpuDevice, Eigen::half, 4>;
+template struct NCHWToNHWC<Eigen::GpuDevice, Eigen::half, 4>;
+template struct PadInput<Eigen::GpuDevice, Eigen::half, int, 4>;
+
+// For 3d ops.
+template struct TransformFilter<Eigen::GpuDevice, Eigen::half, int, 5>;
+template struct ReverseTransformFilter<Eigen::GpuDevice, Eigen::half, 5>;
+template struct NHWCToNCHW<Eigen::GpuDevice, Eigen::half, 5>;
+template struct NCHWToNHWC<Eigen::GpuDevice, Eigen::half, 5>;
+template struct PadInput<Eigen::GpuDevice, Eigen::half, int, 5>;
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_2d_gpu_int.cu.cc b/tensorflow/core/kernels/conv_2d_gpu_int.cu.cc
new file mode 100644
index 0000000..901ce3e
--- /dev/null
+++ b/tensorflow/core/kernels/conv_2d_gpu_int.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <algorithm>
+#include <array>
+#include <limits>
+#include <utility>
+
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_2d_gpu.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+// For 2d ops.
+template struct PadInput<Eigen::GpuDevice, int, int, 4>;
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_2d_gpu_uint16.cu.cc b/tensorflow/core/kernels/conv_2d_gpu_uint16.cu.cc
new file mode 100644
index 0000000..e47532a
--- /dev/null
+++ b/tensorflow/core/kernels/conv_2d_gpu_uint16.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <algorithm>
+#include <array>
+#include <limits>
+#include <utility>
+
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_2d_gpu.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, uint16>;
+template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, uint16>;
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_2d_gpu_uint32.cu.cc b/tensorflow/core/kernels/conv_2d_gpu_uint32.cu.cc
new file mode 100644
index 0000000..56cd5dd
--- /dev/null
+++ b/tensorflow/core/kernels/conv_2d_gpu_uint32.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <algorithm>
+#include <array>
+#include <limits>
+#include <utility>
+
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_2d_gpu.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, uint32>;
+template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, uint32>;
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_2d_gpu_uint64.cu.cc b/tensorflow/core/kernels/conv_2d_gpu_uint64.cu.cc
new file mode 100644
index 0000000..045a664
--- /dev/null
+++ b/tensorflow/core/kernels/conv_2d_gpu_uint64.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <algorithm>
+#include <array>
+#include <limits>
+#include <utility>
+
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_2d_gpu.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, uint64>;
+template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, uint64>;
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_2d_gpu_uint8.cu.cc b/tensorflow/core/kernels/conv_2d_gpu_uint8.cu.cc
new file mode 100644
index 0000000..2154178
--- /dev/null
+++ b/tensorflow/core/kernels/conv_2d_gpu_uint8.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <algorithm>
+#include <array>
+#include <limits>
+#include <utility>
+
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_2d_gpu.h"
+
+namespace tensorflow {
+
+namespace functor {
+
+template struct SwapDimension1And2InTensor3<Eigen::GpuDevice, uint8>;
+template struct SwapDimension0And2InTensor3<Eigen::GpuDevice, uint8>;
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc
index 87bbc30..bf98acd 100644
--- a/tensorflow/core/kernels/conv_ops_test.cc
+++ b/tensorflow/core/kernels/conv_ops_test.cc
@@ -30,6 +30,7 @@
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
@@ -549,8 +550,22 @@
tensorflow::GraphDef graph;
TF_ASSERT_OK(root.ToGraphDef(&graph));
+ // `FusedConv2D` is available only on CPU, and in this test we don't want to
+ // compare GPU vs CPU numbers, so place all nodes on CPU.
+ for (NodeDef& mutable_node : *graph.mutable_node()) {
+ mutable_node.set_device("/device:CPU:0");
+ }
+
+ // Disable Grappler constant folding for the test graphs.
+ tensorflow::SessionOptions session_options;
+ tensorflow::RewriterConfig* cfg =
+ session_options.config.mutable_graph_options()
+ ->mutable_rewrite_options();
+ cfg->set_constant_folding(tensorflow::RewriterConfig::OFF);
+
std::unique_ptr<tensorflow::Session> session(
- tensorflow::NewSession(tensorflow::SessionOptions()));
+ tensorflow::NewSession(session_options));
+
TF_ASSERT_OK(session->Create(graph));
std::vector<Tensor> unfused_tensors;
@@ -698,12 +713,15 @@
Tensor image(dtype, {image_batch_count, image_height, image_width, depth});
image.flat<T>() = image.flat<T>().setRandom();
+ // Add some negative values to filter to properly test Relu.
Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
filter.flat<T>() = filter.flat<T>().setRandom();
+ filter.flat<T>() -= filter.flat<T>().constant(static_cast<T>(0.5f));
const int bias_size = filter_count;
Tensor bias(dtype, {bias_size});
bias.flat<T>() = bias.flat<T>().setRandom();
+ bias.flat<T>() += bias.flat<T>().constant(static_cast<T>(0.5f));
Tensor conv_2d;
Tensor fused_conv_2d;
@@ -714,7 +732,14 @@
ASSERT_EQ(conv_2d.dtype(), fused_conv_2d.dtype());
ASSERT_EQ(conv_2d.shape(), fused_conv_2d.shape());
- test::ExpectTensorNear<T>(conv_2d, fused_conv_2d, 1e-5);
+ // NOTE(ezhulenev): When filter size is equal to the input image size, we
+ // effectevily do element-wise product and full sum reduction, and these
+ // operations intoroduce higher than "normal" numerical errors.
+ if (image_width == filter_size && image_height == filter_size) {
+ test::ExpectTensorNear<T>(conv_2d, fused_conv_2d, 1e-3);
+ } else {
+ test::ExpectClose(conv_2d, fused_conv_2d);
+ }
}
void VerifyFusedBatchNormTensorsNear(int depth, int image_width,
@@ -727,8 +752,10 @@
Tensor image(dtype, {image_batch_count, image_height, image_width, depth});
image.flat<T>() = image.flat<T>().setRandom();
+ // Add some negative values to filter to properly test Relu.
Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
filter.flat<T>() = filter.flat<T>().setRandom();
+ filter.flat<T>() -= filter.flat<T>().constant(static_cast<T>(0.5f));
const int scale_size = filter_count;
@@ -754,7 +781,14 @@
ASSERT_EQ(conv_2d.dtype(), fused_conv_2d.dtype());
ASSERT_EQ(conv_2d.shape(), fused_conv_2d.shape());
- test::ExpectTensorNear<T>(conv_2d, fused_conv_2d, 1e-3);
+ // NOTE(ezhulenev): When filter size is equal to the input image size, we
+ // effectevily do element-wise product and full sum reduction, and these
+ // operations intoroduce higher than "normal" numerical errors.
+ if (image_width == filter_size && image_height == filter_size) {
+ test::ExpectTensorNear<T>(conv_2d, fused_conv_2d, 1e-3);
+ } else {
+ test::ExpectClose(conv_2d, fused_conv_2d);
+ }
}
// Verifies that computing Conv2D+BiasAdd in a graph is identical to
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 3f7aa0d..313def9 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -449,6 +449,27 @@
enum { Cost = 4 * NumTraits<Scalar>::AddCost, PacketAccess = false };
};
+template <typename Scalar>
+struct scalar_round_up_op {
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
+ operator()(const Scalar& x) const {
+ EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
+ NUMERIC_TYPE_MUST_BE_REAL)
+
+ Scalar round_val = Eigen::numext::floor(x);
+ const Scalar fraction = x - round_val;
+ if (fraction >= Scalar(.5)) {
+ round_val += Scalar(1.0);
+ }
+ return round_val;
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<scalar_round_up_op<Scalar>> {
+ enum { Cost = 4 * NumTraits<Scalar>::AddCost, PacketAccess = false };
+};
+
#undef ENABLE_FLOAT_EQUALITY_WARNING
#undef DISABLE_FLOAT_EQUALITY_WARNING
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index b7ccf5f..f1eeda2 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -600,6 +600,7 @@
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:session_options",
"//tensorflow/core/kernels:ops_util",
+ "@com_google_absl//absl/memory",
],
)
@@ -620,6 +621,10 @@
name = "optional_ops",
srcs = ["optional_ops.cc"],
hdrs = ["optional_ops.h"],
+ gpu_srcs = [
+ "optional_ops.cu.cc",
+ "optional_ops.h",
+ ],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
@@ -627,6 +632,7 @@
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
+ "//third_party/eigen3",
],
)
@@ -663,18 +669,6 @@
)
tf_kernel_library(
- name = "matching_files_dataset_op",
- srcs = ["matching_files_dataset_op.cc"],
- deps = [
- ":dataset",
- "//tensorflow/core:dataset_ops_op_lib",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- ],
-)
-
-tf_kernel_library(
name = "model_dataset_op",
srcs = ["model_dataset_op.cc"],
deps = [
@@ -718,7 +712,6 @@
":map_and_batch_dataset_op",
":map_dataset_op",
":map_defun_op",
- ":matching_files_dataset_op",
":model_dataset_op",
":multi_device_iterator_ops",
":optimize_dataset_op",
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index 1a18864..958c42a 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -158,6 +158,18 @@
)
tf_kernel_library(
+ name = "matching_files_dataset_op",
+ srcs = ["matching_files_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels/data:dataset",
+ ],
+)
+
+tf_kernel_library(
name = "dataset_kernels",
deps = [
":assert_next_dataset_op",
@@ -166,6 +178,7 @@
":ignore_errors_dataset_op",
":indexed_dataset",
":lmdb_dataset_op",
+ ":matching_files_dataset_op",
":non_serializable_dataset_op",
":numa_map_and_batch_dataset_op",
":prefetching_kernels",
diff --git a/tensorflow/core/kernels/data/matching_files_dataset_op.cc b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc
similarity index 98%
rename from tensorflow/core/kernels/data/matching_files_dataset_op.cc
rename to tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc
index d36b9e7..aa27a13 100644
--- a/tensorflow/core/kernels/data/matching_files_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc
@@ -366,8 +366,9 @@
};
};
-REGISTER_KERNEL_BUILDER(Name("MatchingFilesDataset").Device(DEVICE_CPU),
- MatchingFilesDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalMatchingFilesDataset").Device(DEVICE_CPU),
+ MatchingFilesDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
index ab21dfc..335f2b7 100644
--- a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
@@ -13,10 +13,12 @@
limitations under the License.
==============================================================================*/
+#include <memory>
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
@@ -187,20 +189,135 @@
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ return dataset()->input_->MakeIterator(
+ IteratorContext(CreateParams(ctx)), prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
+ return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
+ out_tensors, end_of_sequence);
+ }
+
+ protected:
+ std::shared_ptr<model::Node> CreateNode(
+ IteratorContext* ctx, model::Node::Args args) const override {
+ return model::MakeKnownRatioNode(std::move(args),
+ /*ratio=*/1);
+ }
+
+ private:
+ IteratorContext::Params CreateParams(IteratorContext* ctx) {
ThreadPoolResource* pool = dataset()->threadpool_;
IteratorContext::Params params(ctx);
params.runner = [pool](std::function<void()> c) {
pool->Schedule(std::move(c));
};
params.runner_threadpool_size = pool->NumThreads();
- IteratorContext iter_ctx(params);
- return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
+ return params;
+ }
+
+ std::unique_ptr<IteratorBase> input_impl_;
+ };
+
+ const DatasetBase* const input_;
+ const Tensor resource_handle_;
+ ThreadPoolResource* const threadpool_;
+ };
+};
+
+class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit MaxIntraOpParallelismDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 max_intra_op_parallelism;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "max_intra_op_parallelism",
+ &max_intra_op_parallelism));
+ OP_REQUIRES(
+ ctx, max_intra_op_parallelism >= 0,
+ errors::InvalidArgument("`max_intra_op_parallelism` must be >= 0"));
+ *output = new Dataset(ctx, input, max_intra_op_parallelism);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ int max_intra_op_parallelism)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ max_intra_op_parallelism_(max_intra_op_parallelism) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::MaxIntraOpParallelism")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+ return "MaxIntraOpParallelismDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* max_intra_op_parallelism_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(max_intra_op_parallelism_,
+ &max_intra_op_parallelism_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, max_intra_op_parallelism_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ IteratorContext::Params params(ctx);
+ auto max_parallelism = dataset()->max_intra_op_parallelism_;
+ params.runner = std::bind(
+ [max_parallelism](
+ const std::function<void(std::function<void()>)>& runner,
+ std::function<void()> fn) {
+ std::function<void()> scoped_fn = std::bind(
+ [max_parallelism](const std::function<void()>& fn) {
+ ScopedPerThreadMaxParallelism scope(max_parallelism);
+ fn();
+ },
+ std::move(fn));
+ (runner)(std::move(scoped_fn));
+ },
+ std::move(*ctx->runner()), std::placeholders::_1);
+ return input_impl_->GetNext(IteratorContext{std::move(params)},
+ out_tensors, end_of_sequence);
}
protected:
@@ -215,11 +332,116 @@
};
const DatasetBase* const input_;
- const Tensor resource_handle_;
- ThreadPoolResource* const threadpool_;
+ const int max_intra_op_parallelism_;
};
};
+class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit PrivateThreadPoolDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 num_threads;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<int64>(ctx, "num_threads", &num_threads));
+ OP_REQUIRES(ctx, num_threads >= 1,
+ errors::InvalidArgument("`num_threads` must be >= 1"));
+ *output = new Dataset(ctx, input, num_threads);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input, int num_threads)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ num_threads_(num_threads) {
+ thread_pool_ = MakeUnique<thread::ThreadPool>(
+ ctx->env(), ThreadOptions{}, "data_private_threadpool", num_threads,
+ /*low_latency_hint=*/false);
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::PrivateThreadPool")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+ return "PrivateThreadPoolDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* num_threads_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(num_threads_, &num_threads_node));
+ TF_RETURN_IF_ERROR(
+ b->AddDataset(this, {input_graph_node, num_threads_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ thread::ThreadPool* pool = dataset()->thread_pool_.get();
+ IteratorContext::Params params(ctx);
+ params.runner = [pool](std::function<void()> c) {
+ pool->Schedule(std::move(c));
+ };
+ params.runner_threadpool_size = dataset()->num_threads_;
+ return input_impl_->GetNext(IteratorContext{std::move(params)},
+ out_tensors, end_of_sequence);
+ }
+
+ protected:
+ std::shared_ptr<model::Node> CreateNode(
+ IteratorContext* ctx, model::Node::Args args) const override {
+ return model::MakeKnownRatioNode(std::move(args),
+ /*ratio=*/1);
+ }
+
+ private:
+ std::unique_ptr<IteratorBase> input_impl_;
+ };
+
+ const DatasetBase* const input_;
+ const int64 num_threads_;
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalMaxIntraOpParallelismDataset").Device(DEVICE_CPU),
+ MaxIntraOpParallelismDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalPrivateThreadPoolDataset").Device(DEVICE_CPU),
+ PrivateThreadPoolDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
ThreadPoolHandleOp);
REGISTER_KERNEL_BUILDER(
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 93999dc..98b6745 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -14,6 +14,7 @@
==============================================================================*/
#include "tensorflow/core/kernels/data/iterator_ops.h"
+#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
@@ -545,10 +546,9 @@
// in its resource manager. The existing device will outlive the
// IteratorResource, because we are storing the IteratorResource
// in that device's resource manager.
- Device* wrapped_device = RenamedDevice::NewRenamedDevice(
+ *device_mgr = absl::make_unique<DeviceMgr>(RenamedDevice::NewRenamedDevice(
ctx->device()->name(), down_cast<Device*>(ctx->device()),
- false /* owns_underlying */, false /* isolate_session_state */);
- device_mgr->reset(new DeviceMgr({wrapped_device}));
+ false /* owns_underlying */, false /* isolate_session_state */));
flib_def->reset(new FunctionLibraryDefinition(
*ctx->function_library()->GetFunctionLibraryDefinition()));
pflr->reset(new ProcessFunctionLibraryRuntime(
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 72a401e..f389ff1 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -566,7 +566,10 @@
RecordStart(ctx.get());
auto stop_cleanup =
gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
- new_calls.reserve(num_parallel_calls_->value);
+ {
+ tf_shared_lock l(*mu_); // mu_ == num_parallel_calls_->mu
+ new_calls.reserve(num_parallel_calls_->value);
+ }
auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
int64 num_parallel_calls = num_parallel_calls_->value;
int64 max_batch_results =
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index f90dcb9..f5bb35d 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -29,6 +29,7 @@
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
@@ -56,8 +57,13 @@
ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations));
Dataset* dataset =
new Dataset(ctx, input, optimizations, output_types_, output_shapes_);
- OP_REQUIRES_OK(ctx, dataset->Optimize(ctx));
- *output = dataset;
+ Status s = dataset->Optimize(ctx);
+ if (s.ok()) {
+ *output = dataset;
+ } else {
+ dataset->Unref();
+ OP_REQUIRES_OK(ctx, s);
+ }
}
private:
@@ -68,6 +74,7 @@
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)),
+ optimized_input_(nullptr),
input_(input),
optimizations_(optimizations),
output_types_(output_types),
@@ -77,7 +84,9 @@
~Dataset() override {
input_->Unref();
- optimized_input_->Unref();
+ if (optimized_input_) {
+ optimized_input_->Unref();
+ }
}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
index 2ab5c83..bee857f 100644
--- a/tensorflow/core/kernels/data/optional_ops.cc
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -22,75 +22,6 @@
namespace tensorflow {
namespace data {
namespace {
-const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
-
-// An `OptionalVariant` can represent either an "actual value" (a tuple of
-// tensors) or "none", and may be stored in a DT_VARIANT tensor.
-class OptionalVariant {
- public:
- // Create an `OptionalVariant` with no actual value.
- OptionalVariant() : values_(nullptr) {}
-
- // Create an `OptionalVariant` with the actual value given by the tuple of
- // tensors in `values`.
- explicit OptionalVariant(std::vector<Tensor> values)
- : values_(new std::vector<Tensor>(std::move(values))) {}
-
- OptionalVariant(const OptionalVariant& other) : values_(other.values_) {}
-
- // Returns true if `this` represents an actual value.
- bool has_value() const { return values_ != nullptr; }
-
- // REQUIRES: `this->has_value()` must be true.
- const std::vector<Tensor>& get_values() const {
- CHECK(values_) << "Tried to get values from an empty OptionalVariant";
- return *values_;
- }
-
- // Implementations of the necessary methods for using `OptionalVariant`
- // objects in DT_VARIANT tensors.
- string TypeName() const { return kOptionalVariantTypeName; }
- void Encode(VariantTensorData* data) const {
- data->set_metadata(values_ != nullptr);
- if (values_ != nullptr) {
- for (const auto& t : *values_) {
- *(data->add_tensors()) = t;
- }
- }
- }
-
- bool Decode(const VariantTensorData& data) {
- if (data.type_name() != TypeName()) {
- return false;
- }
- bool has_value = false;
- if (!data.get_metadata(&has_value)) {
- return false;
- }
- if (has_value) {
- values_.reset(new std::vector<Tensor>(data.tensors()));
- } else {
- values_.reset();
- }
- return true;
- }
-
- string DebugString() const {
- if (values_) {
- return strings::StrCat("OptionalVariant<", "values: (",
- str_util::Join(*values_, ", ",
- [](string* s, const Tensor& elem) {
- *s = elem.DebugString();
- }),
- ")>");
- } else {
- return strings::StrCat("OptionalVariant<None>");
- }
- }
-
- private:
- std::shared_ptr<const std::vector<Tensor>> values_;
-};
class OptionalNoneOp : public OpKernel {
public:
@@ -143,6 +74,12 @@
explicit OptionalGetValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES(
+ ctx, output_shapes_.size() == output_types_.size(),
+ errors::InvalidArgument(
+ "output_types and output_shapes must be same length, got:\n",
+ "output_types: ", output_types_.size(), "\n",
+ "output_shapes: ", output_shapes_.size()));
}
void Compute(OpKernelContext* ctx) override {
@@ -162,6 +99,10 @@
ctx, optional->has_value(),
errors::InvalidArgument("The given optional does not have a value."));
const auto& components = optional->get_values();
+ OP_REQUIRES(ctx, components.size() == output_types_.size(),
+ errors::InvalidArgument(
+ "The given optional has ", components.size(),
+ " components, expected ", output_types_.size()));
for (int i = 0; i < components.size(); ++i) {
OP_REQUIRES(
ctx, components[i].dtype() == output_types_[i],
@@ -213,15 +154,7 @@
std::vector<Tensor> to_values;
to_values.reserve(from_values.size());
for (const Tensor& t : from_values) {
- if (t.dtype() == DT_VARIANT) {
- // TODO(b/116349787): Implement support for nested variants.
- return errors::Unimplemented(
- "Support for copying nested variants to device has not yet been "
- "implemented.");
- }
- }
- for (const Tensor& t : from_values) {
- if (DMAHelper::CanUseDMA(&t)) {
+ if (DMAHelper::CanUseDMA(&t) || t.dtype() == DT_VARIANT) {
Tensor tmp(t.dtype());
TF_RETURN_IF_ERROR(copy(t, &tmp));
to_values.push_back(std::move(tmp));
@@ -272,5 +205,20 @@
return Status::OK();
}
+REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
+ DEVICE_CPU, OptionalVariant,
+ OptionalZerosLike<CPUDevice>);
+
+REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
+ OptionalVariant,
+ OptionalBinaryAdd<CPUDevice>);
+
+Status OptionalShape(const OptionalVariant& x, TensorShape* s) {
+ *s = TensorShape({});
+ return Status::OK();
+}
+
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(OptionalVariant, OptionalShape);
+
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optional_ops.cu.cc b/tensorflow/core/kernels/data/optional_ops.cu.cc
new file mode 100644
index 0000000..eb4a95a
--- /dev/null
+++ b/tensorflow/core/kernels/data/optional_ops.cu.cc
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#define EIGEN_USE_THREADS
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/data/optional_ops.h"
+
+#include "tensorflow/core/framework/variant_op_registry.h"
+
+namespace tensorflow {
+namespace data {
+
+REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
+ DEVICE_GPU, OptionalVariant,
+ OptionalZerosLike<GPUDevice>);
+
+REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
+ OptionalVariant,
+ OptionalBinaryAdd<GPUDevice>);
+
+} // namespace data
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/data/optional_ops.h b/tensorflow/core/kernels/data/optional_ops.h
index 2cbf293..ef14e84 100644
--- a/tensorflow/core/kernels/data/optional_ops.h
+++ b/tensorflow/core/kernels/data/optional_ops.h
@@ -19,10 +19,13 @@
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
+#include "tensorflow/core/util/tensor_ops_util.h"
namespace tensorflow {
namespace data {
+const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
+
// Stores a DT_VARIANT value representing an Optional with the given value
// in the `output_index`^th output of the given kernel execution context.
Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
@@ -32,6 +35,122 @@
// in the `output_index`^th output of the given kernel execution context.
Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);
+// An `OptionalVariant` can represent either an "actual value" (a tuple of
+// tensors) or "none", and may be stored in a DT_VARIANT tensor.
+class OptionalVariant {
+ public:
+ // Create an `OptionalVariant` with no actual value.
+ OptionalVariant() : values_(nullptr) {}
+
+ // Create an `OptionalVariant` with the actual value given by the tuple of
+ // tensors in `values`.
+ explicit OptionalVariant(std::vector<Tensor> values)
+ : values_(new std::vector<Tensor>(std::move(values))) {}
+
+ OptionalVariant(const OptionalVariant& other) : values_(other.values_) {}
+
+ // Returns true if `this` represents an actual value.
+ bool has_value() const { return values_ != nullptr; }
+
+ // REQUIRES: `this->has_value()` must be true.
+ const std::vector<Tensor>& get_values() const {
+ DCHECK(values_) << "Tried to get values from an empty OptionalVariant";
+ return *values_;
+ }
+
+ // Implementations of the necessary methods for using `OptionalVariant`
+ // objects in DT_VARIANT tensors.
+ string TypeName() const { return kOptionalVariantTypeName; }
+ void Encode(VariantTensorData* data) const {
+ data->set_metadata(values_ != nullptr);
+ if (values_ != nullptr) {
+ for (const auto& t : *values_) {
+ *(data->add_tensors()) = t;
+ }
+ }
+ }
+
+ bool Decode(const VariantTensorData& data) {
+ if (data.type_name() != TypeName()) {
+ return false;
+ }
+ bool has_value = false;
+ if (!data.get_metadata(&has_value)) {
+ return false;
+ }
+ if (has_value) {
+ values_.reset(new std::vector<Tensor>(data.tensors()));
+ } else {
+ values_.reset();
+ }
+ return true;
+ }
+
+ string DebugString() const {
+ if (values_) {
+ return strings::StrCat("OptionalVariant<", "values: (",
+ str_util::Join(*values_, ", ",
+ [](string* s, const Tensor& elem) {
+ *s = elem.DebugString();
+ }),
+ ")>");
+ } else {
+ return strings::StrCat("OptionalVariant<None>");
+ }
+ }
+
+ private:
+ std::shared_ptr<const std::vector<Tensor>> values_;
+};
+
+template <typename Device>
+Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x,
+ OptionalVariant* y) {
+ if (!x.has_value()) {
+ *y = x;
+ return Status::OK();
+ }
+ std::vector<Tensor> zero_tensors;
+ for (const Tensor& tensor : x.get_values()) {
+ Tensor zero_t;
+ TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(ctx, tensor, &zero_t));
+ zero_tensors.push_back(std::move(zero_t));
+ }
+ *y = OptionalVariant(zero_tensors);
+ return Status::OK();
+}
+
+template <typename Device>
+Status OptionalBinaryAdd(OpKernelContext* ctx, const OptionalVariant& a,
+ const OptionalVariant& b, OptionalVariant* out) {
+ // TODO(skyewm): should adding a value to a non-value be a no-op instead?
+ if (a.has_value() != b.has_value()) {
+ return errors::InvalidArgument(
+ "Cannot add optionals because one has a value and the other doesn't.");
+ }
+ if (!a.has_value()) {
+ *out = a;
+ return Status::OK();
+ }
+ if (a.get_values().size() != b.get_values().size()) {
+ return errors::InvalidArgument(
+ "Cannot add optionals because they have different numbers of "
+ "components (",
+ a.get_values().size(), " vs. ", b.get_values().size(), ").");
+ }
+ std::vector<Tensor> out_tensors;
+ for (int i = 0; i < a.get_values().size(); ++i) {
+ const Tensor& a_tensor = a.get_values()[i];
+ const Tensor& b_tensor = b.get_values()[i];
+ Tensor out_tensor;
+ TF_RETURN_IF_ERROR(
+ BinaryAddTensors<Device>(ctx, a_tensor, b_tensor, &out_tensor));
+ out_tensors.push_back(std::move(out_tensor));
+ }
+ *out = OptionalVariant(out_tensors);
+ return Status::OK();
+}
+
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 985e197..23e6adc 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -1241,7 +1241,7 @@
element_in_use_(params.dataset->cycle_length_, false),
thread_pool_(new thread::ThreadPool(
Env::Default(), ThreadOptions(),
- "tf_data_parallel_interleave_worker_pool",
+ "data_parallel_interleave_worker_pool",
dataset()->cycle_length_ /* num_threads */,
false /* low_latency_hint */)) {
std::vector<string> components =
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index ec1c923..5d6c12e 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -252,7 +252,10 @@
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
- new_calls.reserve(num_parallel_calls_->value);
+ {
+ tf_shared_lock l(*mu_); // mu_ == num_parallel_calls_->mu
+ new_calls.reserve(num_parallel_calls_->value);
+ }
auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
int64 num_parallel_calls = num_parallel_calls_->value;
return num_calls_ >= num_parallel_calls ||
diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc
index 5b084a1..89e3881 100644
--- a/tensorflow/core/kernels/data/single_threaded_executor.cc
+++ b/tensorflow/core/kernels/data/single_threaded_executor.cc
@@ -65,21 +65,28 @@
if (IsRefType(dt)) {
return errors::Unimplemented(
"Single-threaded executor does not support reference-typed "
- "edges.");
+ "edges. But saw type ",
+ DataTypeString(dt), " in outputs of node ", n->name());
}
}
if (n->IsControlFlow()) {
return errors::Unimplemented(
- "Single-threaded executor does not support control flow.");
+ "Single-threaded executor does not support control flow. But saw "
+ "control flow node ",
+ n->name());
}
if (n->IsSend() || n->IsHostSend() || n->IsRecv() || n->IsHostRecv()) {
return errors::Unimplemented(
- "Single-threaded executor does not support partitioned graphs.");
+ "Single-threaded executor does not support partitioned graphs. "
+ "But saw send/recv node ",
+ n->name());
}
if (n->IsCollective()) {
return errors::Unimplemented(
- "Single-threaded executor does not support collective ops.");
+ "Single-threaded executor does not support collective ops. But "
+ "saw collective node ",
+ n->name());
}
KernelState& kernel_state = kernels_[i];
diff --git a/tensorflow/core/kernels/data/single_threaded_executor_test.cc b/tensorflow/core/kernels/data/single_threaded_executor_test.cc
index 6244e28..7bb51fb 100644
--- a/tensorflow/core/kernels/data/single_threaded_executor_test.cc
+++ b/tensorflow/core/kernels/data/single_threaded_executor_test.cc
@@ -51,17 +51,17 @@
// when the test completes.
CHECK(rendez_->Unref());
delete exec_;
- delete device_;
}
// Resets executor_ with a new executor based on a graph 'gdef'.
void Create(std::unique_ptr<const Graph> graph) {
const int version = graph->versions().producer();
LocalExecutorParams params;
- params.device = device_;
+ params.device = device_.get();
params.create_kernel = [this, version](const NodeDef& ndef,
OpKernel** kernel) {
- return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
+ return CreateNonCachedKernel(device_.get(), nullptr, ndef, version,
+ kernel);
};
params.delete_kernel = [](OpKernel* kernel) {
DeleteNonCachedKernel(kernel);
@@ -86,7 +86,7 @@
return exec_->Run(args);
}
- Device* device_ = nullptr;
+ std::unique_ptr<Device> device_;
Executor* exec_ = nullptr;
Executor::Args::Runner runner_;
Rendezvous* rendez_ = nullptr;
diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
index b32ab8b..af7f676 100644
--- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
@@ -54,6 +54,8 @@
}
}
+ ~Dataset() override { input_->Unref(); }
+
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
diff --git a/tensorflow/core/kernels/deep_conv2d.cc b/tensorflow/core/kernels/deep_conv2d.cc
index f9c8f16..750c031 100644
--- a/tensorflow/core/kernels/deep_conv2d.cc
+++ b/tensorflow/core/kernels/deep_conv2d.cc
@@ -434,10 +434,9 @@
tile_spatial_size, base_filter_spatial_size, transform_matrix);
auto shard = [&ctx, &args, &transform, &base_filter_rows, &base_filter_cols,
- &num_filters_transform, &in_depth, &out_depth,
- &filter_shards_row, &filter_shards_col, &tile_spatial_size,
- &filter_in, &transform_matrix,
- &filter_out](int64 start, int64 limit) {
+ &num_filters_transform, &in_depth, &filter_shards_row,
+ &filter_shards_col, &tile_spatial_size, &filter_in,
+ &transform_matrix, &filter_out](int64 start, int64 limit) {
// Allocate buffer for pre-processed filter:
// [base_filter_rows, base_filter_cols, num_filters_transform, in_depth]
//
@@ -533,9 +532,9 @@
const int64 out_depth = args.out_depth;
const int64 num_filters = filter_shards_row * filter_shards_col * out_depth;
- auto shard = [&ctx, &packed_filters, &filter_transform_data,
- &tile_spatial_size, &in_depth, &out_depth, &filter_shards_row,
- &filter_shards_col, &num_filters](int64 start, int64 limit) {
+ auto shard = [&ctx, &packed_filters, &filter_transform_data, &in_depth,
+ &out_depth, &filter_shards_row, &filter_shards_col,
+ &num_filters](int64 start, int64 limit) {
const int64 filter_coord_stride = num_filters * in_depth;
for (int64 i = start; i < limit; ++i) {
// Allocate filter buffer [out_depth, shard_rows, shard_cols, in_depth].
@@ -788,7 +787,7 @@
const int64 shard_base = sr * filter_shards_col + sc;
const int64 out_buf_base = tile_base + out_depth_base + shard_base;
- // Calcuate output indices and outputs to drop (if needed).
+ // Calculate output indices and outputs to drop (if needed).
const int64 out_r_start =
in_r + args.pad_rows - sr * tile_stride_rows;
// NOTE: The index 't' for 'num_tiles is used in index calculation
@@ -1004,9 +1003,9 @@
out_tile_spatial_size, tile_spatial_size, output_transform_matrix);
auto shard = [&ctx, &args, &transform, &packed_filters, &in_depth,
- out_depth, tile_rows, tile_cols, out_tile_rows, out_tile_cols,
- filter_shards_row, filter_shards_col, tile_spatial_size,
- &input, &tile_transform_matrix, &output_transform_matrix,
+ out_depth, out_tile_rows, out_tile_cols, filter_shards_row,
+ filter_shards_col, tile_spatial_size, &input,
+ &tile_transform_matrix, &output_transform_matrix,
&output](int64 batch_start, int64 batch_limit) {
const int64 row_tiles =
(args.out_rows + out_tile_rows - 1) / out_tile_rows +
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 1398c87..e811968 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -175,7 +175,7 @@
assert(CanLaunchDepthwiseConv2dGPUSmall(args));
// Holds block plus halo and filter data for blockDim.x depths.
extern __shared__ __align__(8) unsigned char shared_memory[];
- static_assert(sizeof(S) <= 8, "Insufficient alignement detected");
+ static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
S* const shared_data = reinterpret_cast<S*>(shared_memory);
const int num_batches = args.batch;
@@ -459,7 +459,7 @@
assert(CanLaunchDepthwiseConv2dGPUSmall(args));
// Holds block plus halo and filter data for blockDim.z depths.
extern __shared__ __align__(8) unsigned char shared_memory[];
- static_assert(sizeof(S) <= 8, "Insufficient alignement detected");
+ static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
S* const shared_data = reinterpret_cast<S*>(shared_memory);
const int num_batches = args.batch;
@@ -1176,7 +1176,7 @@
assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.z));
// Holds block plus halo and filter data for blockDim.x depths.
extern __shared__ __align__(8) unsigned char shared_memory[];
- static_assert(sizeof(S) <= 8, "Insufficient alignement detected");
+ static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
S* const shared_data = reinterpret_cast<S*>(shared_memory);
const int num_batches = args.batch;
@@ -1448,7 +1448,7 @@
assert(CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, blockDim.x));
// Holds block plus halo and filter data for blockDim.z depths.
extern __shared__ __align__(8) unsigned char shared_memory[];
- static_assert(sizeof(S) <= 8, "Insufficient alignement detected");
+ static_assert(sizeof(S) <= 8, "Insufficient alignment detected");
S* const shared_data = reinterpret_cast<S*>(shared_memory);
const int num_batches = args.batch;
diff --git a/tensorflow/core/kernels/dynamic_partition_op.cc b/tensorflow/core/kernels/dynamic_partition_op.cc
index 3c988db..572d04a 100644
--- a/tensorflow/core/kernels/dynamic_partition_op.cc
+++ b/tensorflow/core/kernels/dynamic_partition_op.cc
@@ -142,7 +142,7 @@
OP_REQUIRES(
c, FastBoundsCheck(p, num_partitions_),
errors::InvalidArgument("indices[", i,
- "] has been asynchronously overwitten and "
+ "] has been asynchronously overwritten and "
"is no longer in range!"));
auto oi = output_index[p];
OP_REQUIRES(c, FastBoundsCheck(oi, out_flat[p].dimension(0)),
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h
index 1f211b1..25c735d 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h
@@ -56,6 +56,7 @@
//
// TODO(ezhulenev): Consolidate this part of the code with the image patch
// extraction code since they are both very similar.
+
template <typename NewDimension, Index Rows, Index Cols, typename ArgType,
typename Device, typename Scalar_, typename Index,
typename nocontract_t, typename contract_t, int Side, int packet_size,
@@ -70,6 +71,7 @@
inner_dim_reordered, Alignment> {
public:
typedef Scalar_ Scalar;
+
typedef TensorContractionInputMapper<
Scalar, Index, Side,
TensorEvaluator<
@@ -79,6 +81,7 @@
nocontract_t, contract_t, packet_size, inner_dim_contiguous,
inner_dim_reordered, Alignment>
Self;
+
typedef TensorContractionSubMapper<
Scalar, Index, Side,
TensorEvaluator<
@@ -88,6 +91,7 @@
nocontract_t, contract_t, packet_size, inner_dim_contiguous,
inner_dim_reordered, Alignment>
SubMapper;
+
typedef SubMapper VectorMapper;
typedef SubMapper LinearMapper;
typedef typename packet_traits<Scalar>::type Packet;
@@ -533,6 +537,7 @@
nocontract_t, contract_t, packet_size, inner_dim_contiguous,
inner_dim_reordered, Alignment>
ParentMapper;
+
typedef TensorContractionSubMapper<
Scalar, Index, Side,
TensorEvaluator<
@@ -542,21 +547,22 @@
nocontract_t, contract_t, packet_size, inner_dim_contiguous,
inner_dim_reordered, Alignment>
Self;
+
typedef Self LinearMapper;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
- : m_base_mapper(base_mapper),
- m_depth_offset(vert_offset),
- m_col_offset(horiz_offset) {
+ : m_depth_offset(vert_offset),
+ m_col_offset(horiz_offset),
+ m_base_mapper(base_mapper) {
m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
m_otherIndex);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
const Self& base_mapper, Index vert_offset, Index horiz_offset)
- : m_base_mapper(base_mapper.m_base_mapper),
- m_depth_offset(vert_offset + base_mapper.m_depth_offset),
- m_col_offset(horiz_offset + base_mapper.m_col_offset) {
+ : m_depth_offset(vert_offset + base_mapper.m_depth_offset),
+ m_col_offset(horiz_offset + base_mapper.m_col_offset),
+ m_base_mapper(base_mapper.m_base_mapper) {
m_base_mapper.computeBaseIndices(m_col_offset, m_rowIndex, m_colIndex,
m_otherIndex);
}
@@ -578,7 +584,6 @@
return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
j + m_col_offset);
}
-
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
loadCoeffStandard(Index i) const {
return m_base_mapper.loadCoeffStandard(i + m_depth_offset, m_rowIndex,
@@ -611,18 +616,29 @@
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
const Index max_col =
- fastPatchColStride().divide(m_depth_offset + peeled_k);
+ (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1)) /
+ fastPatchColStride();
return std::min<Index>(1 + max_col, patchCols());
}
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
const Index col) const {
- const Index max_row = fastPatchRowStride().divide(
- m_depth_offset + peeled_k - col * patchColStride());
+ const Index max_row = (m_depth_offset + (peeled_k == 0 ? 0 : peeled_k - 1) -
+ col * patchColStride()) /
+ fastPatchRowStride();
return std::min<Index>(1 + max_row, patchRows());
}
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxDepth(const Index peeled_k, const Index col,
+ Index row) const {
+ const Index max_depth = m_depth_offset + peeled_k - //
+ col * patchColStride() - //
+ row * patchRowStride();
+ return std::min<Index>(max_depth, patchDepth());
+ }
+
// MaxDepth uses only the remaining number of elements in the peeled_k.
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
@@ -692,6 +708,12 @@
return r < 0 || r >= m_base_mapper.m_inputRows;
}
EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padAnyRow(const Index first_row,
+ const Index last_row) const {
+ return m_rowIndex + first_row < 0 ||
+ m_rowIndex + last_row >= m_base_mapper.m_inputRows;
+ }
+ EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
const Index c = m_colIndex + col;
return c < 0 || c >= m_base_mapper.m_inputCols;
@@ -738,9 +760,6 @@
}
private:
- const ParentMapper m_base_mapper; // Keeping a copy instead of a reference
- // performs better in benchmarks.
-
Index m_depth_offset; // First row in the input matrix
Index m_col_offset; // First col in the input matrix
@@ -750,6 +769,9 @@
Index m_rowIndex;
Index m_colIndex;
Index m_otherIndex;
+
+ const ParentMapper m_base_mapper; // Keeping a copy instead of a reference
+ // performs better in benchmarks.
};
// Arrange a block of the right input matrix (in our case it's always a "virtual
@@ -1319,23 +1341,19 @@
typedef typename packet_traits<Scalar>::type Packet;
EIGEN_DONT_INLINE
- void operator()(Scalar* block, const DataMapper& rhs, StorageIndex rows,
+ void operator()(Scalar* block, const DataMapper rhs, StorageIndex rows,
StorageIndex cols) {
const bool standard_patches = !rhs.nonStandardPatches();
if (standard_patches && (rhs.patchDepth() % packet_size == 0)) {
- if (rhs.rowStride() == 1) {
- packStandardPatches<true, /*squeeze*/ true>(block, rhs, rows, cols);
- } else {
- packStandardPatches<true, /*squeeze*/ false>(block, rhs, rows, cols);
- }
+ // Single packet always belong to single patch (row, col).
+ packStandardPatches</*patch_depth_is_multiple_of_packet_size*/ true>(
+ block, rhs, rows, cols);
} else if (standard_patches) {
- if (rhs.rowStride() == 1) {
- packStandardPatches<false, /*squeeze*/ true>(block, rhs, rows, cols);
- } else {
- packStandardPatches<false, /*squeeze*/ false>(block, rhs, rows, cols);
- }
+ // Single packet can span across multiple patch rows or columns.
+ packStandardPatches</*patch_depth_is_multiple_of_packet_size*/ false>(
+ block, rhs, rows, cols);
} else {
// With non-standard patches we don't do any vectorized loads.
@@ -1357,72 +1375,64 @@
// - patch_depth_is_multiple_of_packet_size=true: We are guaranteed to have
// depth dimension size to be a multiple of packet size, so we can skip all
// non vectorized loads and checks.
- //
- // - squeeze_reads=true: If stride along the `row` dimension is `1`, we can
- // squeeze reads along the `row` and `depth` dimensions, because they are
- // guaranteed to be contiguous in memory (two innermost dimensions).
- //
- template <bool patch_depth_is_multiple_of_packet_size, bool squeeze_reads>
+ template <bool patch_depth_is_multiple_of_packet_size>
EIGEN_ALWAYS_INLINE void packStandardPatches(Scalar* block,
- const DataMapper& rhs,
+ const DataMapper rhs,
StorageIndex rows,
StorageIndex cols) {
eigen_assert(!rhs.nonStandardPatches());
// Give vectorized_rows the name used in all other gemm_pack_rhs above.
- const Index peeled_k = (rows / packet_size) * packet_size;
+ const StorageIndex peeled_k = (rows / packet_size) * packet_size;
- const Index start_col = rhs.colOffset();
- const Index max_col = rhs.maxCol(peeled_k);
+ const StorageIndex start_col = rhs.colOffset();
+ const StorageIndex max_col = rhs.maxCol(peeled_k);
for (StorageIndex col = 0; col < cols; ++col) {
SubMapper lm = rhs.getLinearMapper(0, col);
- Index k = 0;
+ StorageIndex k = 0;
for (Index c = start_col; c < max_col; ++c) {
eigen_assert(k <= peeled_k);
- const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
- const Index max_row = rhs.maxRow(peeled_k, c);
+ const StorageIndex start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const StorageIndex max_row = rhs.maxRow(peeled_k, c);
const bool pad_col = lm.padCol(c);
// We can squeeze reads for all rows in [start_row, max_row) range.
- if (squeeze_reads && !pad_col && !lm.padRow(start_row) &&
- !lm.padRow(max_row - 1)) {
- const Index start_depth = (c == start_col) ? rhs.depthOffset() : 0;
+ if (!pad_col && !lm.padAnyRow(start_row, max_row - 1)) {
+ const StorageIndex start_depth =
+ (c == start_col) ? rhs.depthOffset() : 0;
- // Upper bound on the number of elements in the depth dimension that
- // we can squeeze read.
- const Index squeeze_length =
- (max_row - start_row) * rhs.patchDepth() - start_depth;
+ const StorageIndex max_depth =
+ std::min<StorageIndex>(start_depth + (peeled_k - k),
+ (max_row - start_row) * rhs.patchDepth());
- // Do not overshoot beyond the block size.
- const Index max_depth =
- start_depth + std::min<Index>(peeled_k - k, squeeze_length);
+ const StorageIndex base_idx = lm.baseIndex(start_row, c);
- const Index base_idx = lm.baseIndex(start_row, c);
-
- if (patch_depth_is_multiple_of_packet_size)
+ if (patch_depth_is_multiple_of_packet_size) {
+ // If patch depth is a multiple of packet size, it's guaranteed that
+ // we can process all values in depth dimension with packets.
eigen_assert((max_depth - start_depth) % packet_size == 0);
+ StorageIndex d = start_depth;
- // If patch depth is a multiple of packet size, it's guaranteed that
- // we can process all values in depth dimension with packets.
- const Index max_vectorized_depth =
- patch_depth_is_multiple_of_packet_size ? max_depth
- : max_depth - packet_size;
+ for (; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ internal::pstoreu(block, rhs.packetNoPadding(d, base_idx));
+ block += packet_size;
+ k += packet_size;
+ }
- Index d = start_depth;
+ } else {
+ StorageIndex d = start_depth;
+ const StorageIndex vectorized_depth = max_depth - packet_size;
- // 1. Process depth dimension with vectorized instructions.
- for (; d < max_vectorized_depth; d += packet_size) {
- eigen_assert(k < peeled_k);
- internal::pstoreu(block, rhs.packetNoPadding(d, base_idx));
- block += packet_size;
- k += packet_size;
- }
-
- // 2. Finish with coefficients.
- if (!patch_depth_is_multiple_of_packet_size) {
+ for (; d <= vectorized_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ internal::pstoreu(block, rhs.packetNoPadding(d, base_idx));
+ block += packet_size;
+ k += packet_size;
+ }
for (; d < max_depth; d++) {
eigen_assert(k < peeled_k);
*block = rhs.coeffNoPadding(d, base_idx);
@@ -1437,39 +1447,43 @@
// If we are not allowed to squeeze reads along the `row` and `depth`
// dimensions, we must process rows one by one.
- for (Index r = start_row; r < max_row; ++r) {
+ for (StorageIndex r = start_row; r < max_row; ++r) {
eigen_assert(k <= peeled_k);
- const Index start_depth =
+ const StorageIndex start_depth =
((c == start_col) && (r == start_row)) ? rhs.depthOffset() : 0;
- const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ const StorageIndex max_depth =
+ rhs.maxDepth(peeled_k - k, start_depth);
const bool pad = pad_col || lm.padRow(r);
- const Index base_idx = lm.baseIndex(r, c);
+ const StorageIndex base_idx = lm.baseIndex(r, c);
- if (patch_depth_is_multiple_of_packet_size)
+ if (patch_depth_is_multiple_of_packet_size) {
+ // If patch depth is a multiple of packet size, it's guaranteed that
+ // we can process all values in depth dimension with packets.
eigen_assert((max_depth - start_depth) % packet_size == 0);
+ StorageIndex d = start_depth;
- // If patch depth is a multiple of packet size, it's guaranteed that
- // we can process all values in depth dimension with packets.
- const Index max_vectorized_depth =
- patch_depth_is_multiple_of_packet_size ? max_depth
- : max_depth - packet_size;
+ for (; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ const Packet p = pad ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, base_idx);
+ internal::pstoreu(block, p);
+ block += packet_size;
+ k += packet_size;
+ }
- Index d = start_depth;
-
- // 1. Process depth dimension with vectorized instructions.
- for (; d < max_vectorized_depth; d += packet_size) {
- eigen_assert(k < peeled_k);
- const Packet p = pad ? pset1<Packet>(Scalar(0))
- : rhs.packetNoPadding(d, base_idx);
- internal::pstoreu(block, p);
- block += packet_size;
- k += packet_size;
- }
-
- // 2. Finish with coefficients.
- if (!patch_depth_is_multiple_of_packet_size) {
+ } else {
+ const StorageIndex max_vectorized_depth = max_depth - packet_size;
+ StorageIndex d = start_depth;
+ for (; d < max_vectorized_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ const Packet p = pad ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, base_idx);
+ internal::pstoreu(block, p);
+ block += packet_size;
+ k += packet_size;
+ }
for (; d < max_depth; d++) {
eigen_assert(k < peeled_k);
*block = pad ? Scalar(0) : rhs.coeffNoPadding(d, base_idx);
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc
index 8219fc9..22f71d6 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions_test.cc
@@ -1380,7 +1380,12 @@
/* Filter (kernel) dimensions: */
int filter_count, int filter_cols, int filter_rows,
/* Input strides: */
- int col_strides, int row_strides) {
+ int col_strides, int row_strides,
+ /* Block dimensions: */
+ Index block_rows, Index block_cols) {
+ // Set random seed for benchmark repeatability.
+ srand(12345);
+
tensorflow::testing::UseRealTime();
tensorflow::testing::StopTiming();
@@ -1508,10 +1513,6 @@
PackRhsImpl pack_rhs;
- // This is the typical size of the rhs block used in Tensor contractions.
- const Index default_depth = 320; // must be multiple of 8
- const Index default_cols = 280;
-
const Index packed_total_size = input_dims.TotalSize();
tensorflow::testing::StartTiming();
@@ -1520,11 +1521,14 @@
num_inputs == 1 ? 1 : internal::random<int>(0, num_inputs - 1);
// Depth offset must be a multiple of 8 (float packet size with AVX2).
- Index depth_offset = (internal::random<Index>(0, patch_size - 10) / 8) * 8;
+ Index depth_offset =
+ (patch_size > block_rows)
+ ? (internal::random<Index>(0, patch_size - 10) / 8) * 8
+ : 0;
Index col_offset = internal::random<Index>(0, num_patches - 10);
- Index depth = std::min(default_depth, patch_size - depth_offset);
- Index cols = std::min(default_cols, num_patches - col_offset);
+ Index depth = std::min(block_rows, patch_size - depth_offset);
+ Index cols = std::min(block_cols, num_patches - col_offset);
// Write packed data to random memory location to emulate cold caches.
Index packed_size = depth * cols;
@@ -1538,20 +1542,37 @@
tensorflow::testing::StopTiming();
std::ostringstream stringStream;
- stringStream << "patch: depth=" << patch_depth << " rows=" << patch_rows
- << " cols=" << patch_cols << " num_patches=" << num_patches
+ stringStream << "patch: " << patch_rows << "x" << patch_cols << " D"
+ << patch_depth << "; num_patches=" << num_patches
<< " patch_size=" << patch_size << " num_inputs=" << num_inputs;
tensorflow::testing::SetLabel(stringStream.str());
}
-#define BM_NAME(prefix, N, H, W, C, FC, FH, FW, SH, SW) \
- BM_##prefix##_##N##_##H##x##W##_IC##C##_FC##FC##_##FH##x##FW##_s##SH##x##SW
+// -------------------------------------------------------------------------- //
+// Macro argumentnames:
+// N: batch size
+// H: height
+// W: width
+// C: input channels
+// FC: filter channles
+// FH: filter height
+// SH: stride in height dimensions
+// SW: stride in width dimensions
+// BR: block rows
+// BC: block cols
-#define BM_PackRhs(N, H, W, C, FC, FH, FW, SH, SW) \
- static void BM_NAME(PackRhs, N, H, W, C, FC, FH, FW, SH, SW)(int iters) { \
- PackRhsHelper(iters, N, H, W, C, FC, FH, FW, SH, SW); \
- } \
- BENCHMARK(BM_NAME(PackRhs, N, H, W, C, FC, FH, FW, SH, SW))
+#define BM_CONCAT(a, b) a##b
+
+#define BM_NAME(prefix, N, H, W, C, FC, FH, FW, SH, SW, BR, BC) \
+ BM_CONCAT(BM_##prefix##_##N##_##H##x##W##_IC##C##_FC##FC##_##FH##x##FW, \
+ _s##SH##x##SW##_B##BR##x##BC)
+
+#define BM_PackRhs(N, H, W, C, FC, FH, FW, SH, SW, BR, BC) \
+ static void BM_NAME(PackRhs, N, H, W, C, FC, FH, FW, SH, SW, BR, \
+ BC)(int iters) { \
+ PackRhsHelper(iters, N, H, W, C, FC, FH, FW, SH, SW, BR, BC); \
+ } \
+ BENCHMARK(BM_NAME(PackRhs, N, H, W, C, FC, FH, FW, SH, SW, BR, BC))
// Number of input channel (input depth) it equal to the number of patch
// channels (patch depth).
@@ -1563,14 +1584,16 @@
/*channels*/ 32, //
/*num_filters*/ 64, //
/*filter*/ 5, 5, //
- /*stride*/ 1, 1);
+ /*stride*/ 1, 1, //
+ /*block*/ 256, 56);
BM_PackRhs(/*batch*/ 32, //
/*image*/ 64, 64, //
/*channels*/ 32, //
/*num_filters*/ 64, //
/*filter*/ 5, 5, //
- /*stride*/ 2, 2);
+ /*stride*/ 2, 2, //
+ /*block*/ 256, 56);
// Slow path: input channel dimension is not the multiple of the packet size.
BM_PackRhs(/*batch*/ 32, //
@@ -1578,12 +1601,48 @@
/*channels*/ 30, //
/*num_filters*/ 64, //
/*filter*/ 5, 5, //
- /*stride*/ 1, 1);
+ /*stride*/ 1, 1, //
+ /*block*/ 256, 56);
BM_PackRhs(/*batch*/ 32, //
/*image*/ 64, 64, //
/*channels*/ 30, //
/*num_filters*/ 64, //
/*filter*/ 5, 5, //
- /*stride*/ 2, 2);
+ /*stride*/ 2, 2, //
+ /*block*/ 256, 56);
+
+// Slow path with input channel dimension smaller than the packet size.
+BM_PackRhs(/*batch*/ 32, //
+ /*image*/ 256, 256, //
+ /*channels*/ 4, //
+ /*num_filters*/ 16, //
+ /*filter*/ 8, 8, //
+ /*stride*/ 1, 1, //
+ /*block*/ 256, 56);
+
+BM_PackRhs(/*batch*/ 32, //
+ /*image*/ 256, 256, //
+ /*channels*/ 4, //
+ /*num_filters*/ 16, //
+ /*filter*/ 8, 8, //
+ /*stride*/ 2, 4, //
+ /*block*/ 256, 56);
+
+// Short and wide block with small input channel dimension.
+BM_PackRhs(/*batch*/ 32, //
+ /*image*/ 64, 64, //
+ /*channels*/ 4, //
+ /*num_filters*/ 16, //
+ /*filter*/ 3, 3, //
+ /*stride*/ 1, 1, //
+ /*block*/ 36, 432);
+
+BM_PackRhs(/*batch*/ 32, //
+ /*image*/ 64, 64, //
+ /*channels*/ 4, //
+ /*num_filters*/ 16, //
+ /*filter*/ 3, 3, //
+ /*stride*/ 2, 2, //
+ /*block*/ 36, 432);
} // namespace Eigen
diff --git a/tensorflow/core/kernels/fractional_avg_pool_op.cc b/tensorflow/core/kernels/fractional_avg_pool_op.cc
index 135d002..6123447 100644
--- a/tensorflow/core/kernels/fractional_avg_pool_op.cc
+++ b/tensorflow/core/kernels/fractional_avg_pool_op.cc
@@ -223,7 +223,7 @@
// Once we figure out the original contributors, we just need to evenly
// divide the value of this element among these contributors.
//
- // Internally, we divide the out_backprop tensor and store it in a temparary
+ // Internally, we divide the out_backprop tensor and store it in a temporary
// tensor of double type. And cast it to the corresponding type.
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
ConstEigenMatrixMap;
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index d89f159..dbd3bb0 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -248,7 +248,7 @@
Tensor* saved_inv_var, TensorFormat tensor_format,
bool is_training) {
auto* stream = context->op_device_context()->stream();
- OP_REQUIRES(context, stream, errors::Internal("No GPU stream avalible"));
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));
const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
const int64 channels = GetTensorDim(x, tensor_format, 'C');
@@ -389,7 +389,7 @@
Tensor* scale_backprop, Tensor* offset_backprop,
TensorFormat tensor_format) {
auto* stream = context->op_device_context()->stream();
- OP_REQUIRES(context, stream, errors::Internal("No GPU stream avalible"));
+ OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));
const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
const int64 channels = GetTensorDim(x, tensor_format, 'C');
diff --git a/tensorflow/core/kernels/fuzzing/encode_base64_fuzz.cc b/tensorflow/core/kernels/fuzzing/encode_base64_fuzz.cc
index a8f07f4..b8d779f 100644
--- a/tensorflow/core/kernels/fuzzing/encode_base64_fuzz.cc
+++ b/tensorflow/core/kernels/fuzzing/encode_base64_fuzz.cc
@@ -19,7 +19,7 @@
namespace tensorflow {
namespace fuzzing {
-class FuzzEncodeBase64 : public FuzzSession {
+class FuzzEncodeBase64 : public FuzzStringInputOp {
SINGLE_INPUT_OP_BUILDER(DT_STRING, EncodeBase64);
};
diff --git a/tensorflow/core/kernels/fuzzing/fuzz_session.h b/tensorflow/core/kernels/fuzzing/fuzz_session.h
index 9777be1..57d562d 100644
--- a/tensorflow/core/kernels/fuzzing/fuzz_session.h
+++ b/tensorflow/core/kernels/fuzzing/fuzz_session.h
@@ -72,11 +72,11 @@
// By convention, the graph should have inputs named "input1", ...
// "inputN", and one output node, named "output".
// Users of FuzzSession should override this method to create their graph.
- virtual void BuildGraph(const Scope& scope) {}
+ virtual void BuildGraph(const Scope& scope) = 0;
// Implements the logic that converts an opaque byte buffer
// from the fuzzer to Tensor inputs to the graph. Users must override.
- virtual void FuzzImpl(const uint8_t* data, size_t size) {}
+ virtual void FuzzImpl(const uint8_t* data, size_t size) = 0;
// Initializes the FuzzSession. Not safe for multithreading.
// Separate init function because the call to virtual BuildGraphDef
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index c2591f5..d4adc06 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -30,6 +30,7 @@
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/tensor_ops_util.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {
@@ -357,40 +358,10 @@
for (int i = 0; i < a.tensors.size(); ++i) {
const Tensor& a_tensor = a.tensors[i];
const Tensor& b_tensor = b.tensors[i];
- if (a_tensor.dtype() == DT_INVALID) {
- out->tensors.push_back(b_tensor);
- continue;
- }
- if (b_tensor.dtype() == DT_INVALID) {
- out->tensors.push_back(a_tensor);
- continue;
- }
- if (a_tensor.shape() != b_tensor.shape()) {
- // TODO(apassos) support broadcasting additions here?
- return errors::InvalidArgument(
- "Trying to add two tensors with incompatible element shapes. "
- "One is ",
- a_tensor.shape().DebugString(), " and the other is ",
- b_tensor.shape().DebugString(), " in position ", i);
- }
Tensor out_tensor;
TF_RETURN_IF_ERROR(
- c->allocate_temp(a_tensor.dtype(), a_tensor.shape(), &out_tensor));
+ BinaryAddTensors<Device>(c, a_tensor, b_tensor, &out_tensor));
out->tensors.push_back(out_tensor);
- switch (out_tensor.dtype()) {
-#define DTYPE_CASE(dtype) \
- case DataTypeToEnum<dtype>::value: \
- out_tensor.flat<dtype>().device(c->eigen_device<Device>()) = \
- a_tensor.flat<dtype>() + b_tensor.flat<dtype>(); \
- break;
-
- TF_CALL_NUMBER_TYPES(DTYPE_CASE)
-
-#undef DTYPE_CASE
- default:
- return errors::InvalidArgument("Trying to add unsupported dtype ",
- out_tensor.dtype());
- }
}
return Status::OK();
}
@@ -403,46 +374,7 @@
y->tensors.reserve(x.tensors.size());
for (const Tensor& t : x.tensors) {
Tensor out_tensor;
- AllocatorAttributes attr;
- if (t.dtype() == DT_VARIANT) {
- attr.set_on_host(true);
- }
- TF_RETURN_IF_ERROR(
- c->allocate_temp(t.dtype(), t.shape(), &out_tensor, attr));
- switch (out_tensor.dtype()) {
-#define DTYPE_CASE(dtype) \
- case DataTypeToEnum<dtype>::value: \
- out_tensor.flat<dtype>().device(c->eigen_device<Device>()) = \
- out_tensor.flat<dtype>().constant(dtype(0)); \
- break;
-
- TF_CALL_POD_TYPES(DTYPE_CASE)
-
-#undef DTYPE_CASE
-
- case DT_INVALID: {
- // Uninitialized tensor in the TensorList.
- out_tensor = Tensor(DT_INVALID);
- break;
- }
- case DataTypeToEnum<Variant>::value: {
- const TensorList* inner_x = t.scalar<Variant>()().get<TensorList>();
- if (inner_x == nullptr) {
- return errors::InvalidArgument("Input handle is not a list. Saw: '",
- t.scalar<Variant>()().DebugString(),
- "'");
- }
- TensorList inner_y;
- TF_RETURN_IF_ERROR(TensorListZerosLike<Device>(c, *inner_x, &inner_y));
- out_tensor.scalar<Variant>()() = std::move(inner_y);
- break;
- }
-
- default:
- return errors::InvalidArgument(
- "Trying to compute zeros_like for unsupported dtype ",
- DataTypeString(out_tensor.dtype()));
- }
+ TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(c, t, &out_tensor));
y->tensors.emplace_back(out_tensor);
}
return Status::OK();
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
index 0c7a236..56d0340 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -384,6 +384,7 @@
int32* top_data, const Eigen::GpuDevice& d) {
const int kThreadsPerBlock = 1024;
const int output_size = batch * channels * pooled_height * pooled_width;
+ if (output_size == 0) return true;
MaxPoolForwardNoMaskKernel_NCHW_VECT_C<<<
(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock,
0, d.stream()>>>(output_size, bottom_data, height, width, channels,
@@ -402,6 +403,7 @@
int64* mask, const Eigen::GpuDevice& d, bool propagate_nans) {
const int kThreadsPerBlock = 1024;
const int output_size = batch * channels * pooled_height * pooled_width;
+ if (output_size == 0) return true;
if (propagate_nans) {
MaxPoolForwardNHWC<true>
<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
@@ -430,6 +432,7 @@
const int kThreadsPerBlock = 1024;
const int bottom_size = batch * channels * height * width;
+ if (bottom_size == 0) return true;
SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff);
@@ -449,6 +452,7 @@
const int64* mask, const int top_offset, const int bottom_offset,
T* bottom_diff, const Eigen::GpuDevice& d) {
const int kThreadsPerBlock = 1024;
+ if (input_size == 0) return true;
SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff);
MaxPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
@@ -466,6 +470,7 @@
const int pad_l, const T* top_diff, T* bottom_diff,
const Eigen::GpuDevice& d) {
const int num_kernels = batch * channels * pooled_height * pooled_width;
+ if (num_kernels == 0) return true;
CudaLaunchConfig config = GetCudaLaunchConfig(num_kernels, d);
if (data_format == FORMAT_NHWC) {
@@ -489,6 +494,7 @@
const int output_size, const int input_size, const T* top_diff,
const int64* mask, const int top_offset, const int bottom_offset,
T* bottom_diff, const Eigen::GpuDevice& d) {
+ if (input_size == 0) return true;
CudaLaunchConfig config = GetCudaLaunchConfig(output_size, d);
MaxPoolGradBackward<<<config.block_count, config.thread_per_block, 0,
d.stream()>>>(output_size, top_diff, mask, top_offset,
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index cfab529..094129a 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -56,7 +56,7 @@
MklDnnShape src_mkl_shape;
GetMklShape(context, src_idx, &src_mkl_shape);
- // src_dims is the dimenstion of src_tensor
+ // src_dims is the dimension of src_tensor
// dim of the dst will also be same as src_dims
auto src_tf_shape = src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetTfShape()
@@ -64,12 +64,12 @@
auto src_dims = TFShapeToMklDnnDims(src_tf_shape);
auto output_dims = src_dims;
memory::format layout_type;
- // In MKL, data format passed to mkl softmax op depends on dimension of the input tensor.
- // Here "x" data format in MKL is used for 1 dim tensor, "nc" for 2 dim tensor,
- // "tnc" for 3 dim tensor, "nchw" for 4 dim tensor, and "ncdhw" for 5 dim tensor.
- // Each of the simbols has the following meaning:
- // n = batch, c = channels, t = sequence lenght, h = height,
- // w = width, d = depth
+ // In MKL, data format passed to mkl softmax op depends on dimension of
+ // the input tensor. Here "x" data format in MKL is used for 1 dim tensor,
+ // "nc" for 2 dim tensor, "tnc" for 3 dim tensor, "nchw" for 4 dim tensor,
+ // and "ncdhw" for 5 dim tensor. Each of the simbols has the following
+ // meaning: n = batch, c = channels, t = sequence length, h = height, w =
+ // width, d = depth
switch (input_dims) {
case 1:
layout_type = memory::format::x;
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 89b7449..6c90ffd 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -453,7 +453,7 @@
},
rendez, std::move(done), std::placeholders::_1);
auto* refcounted_done = new ReffedStatusCallback(std::move(callback));
- for (int i = 1; i < handles->size(); ++i) {
+ for (int i = 0; i < handles->size(); ++i) {
refcounted_done->Ref();
}
@@ -507,6 +507,7 @@
});
}
}
+ refcounted_done->Unref();
}
string UniquifyFunctionName(const FunctionLibraryDefinition* function_library,
diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.cc b/tensorflow/core/kernels/quantize_and_dequantize_op.cc
index dadc15b..f13341e 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op.cc
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op.cc
@@ -49,6 +49,21 @@
errors::InvalidArgument("num_bits is out of range: ", num_bits_,
" with signed_input_ ", signed_input_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
+
+ string round_mode_string;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string));
+ OP_REQUIRES(
+ ctx,
+ (round_mode_string == "HALF_UP" || round_mode_string == "HALF_TO_EVEN"),
+ errors::InvalidArgument("Round mode string must be "
+ "'HALF_UP' or "
+ "'HALF_TO_EVEN', is '" +
+ round_mode_string + "'"));
+ if (round_mode_string == "HALF_UP") {
+ round_mode_ = ROUND_HALF_UP;
+ } else if (round_mode_string == "HALF_TO_EVEN") {
+ round_mode_ = ROUND_HALF_TO_EVEN;
+ }
}
void Compute(OpKernelContext* ctx) override {
@@ -76,13 +91,15 @@
functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, num_bits_,
- range_given_, &input_min_tensor, &input_max_tensor, output->flat<T>());
+ range_given_, &input_min_tensor, &input_max_tensor, round_mode_,
+ output->flat<T>());
}
private:
bool signed_input_;
int num_bits_;
bool range_given_;
+ QuantizerRoundMode round_mode_;
};
// Simulate quantization precision loss in a float tensor by:
@@ -135,7 +152,8 @@
functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> f;
f(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_, num_bits_val,
- range_given_, &input_min_tensor, &input_max_tensor, output->flat<T>());
+ range_given_, &input_min_tensor, &input_max_tensor, ROUND_HALF_TO_EVEN,
+ output->flat<T>());
}
private:
@@ -180,7 +198,7 @@
functor::QuantizeAndDequantizeOneScaleFunctor<Device, T> functor;
functor(ctx->eigen_device<Device>(), input.flat<T>(), signed_input_,
num_bits_, range_given_, &input_min_tensor, &input_max_tensor,
- output->flat<T>());
+ ROUND_HALF_TO_EVEN, output->flat<T>());
}
private:
@@ -198,10 +216,11 @@
void operator()(const CPUDevice& d, typename TTypes<T>::ConstVec input,
const bool signed_input, const int num_bits,
const bool range_given, Tensor* input_min_tensor,
- Tensor* input_max_tensor, typename TTypes<T>::Vec out) {
+ Tensor* input_max_tensor, QuantizerRoundMode round_mode,
+ typename TTypes<T>::Vec out) {
QuantizeAndDequantizeOneScaleImpl<CPUDevice, T>::Compute(
d, input, signed_input, num_bits, range_given, input_min_tensor,
- input_max_tensor, out);
+ input_max_tensor, round_mode, out);
}
};
} // namespace functor
diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.h b/tensorflow/core/kernels/quantize_and_dequantize_op.h
index 6b0c5e5..a495e8b 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op.h
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op.h
@@ -22,6 +22,20 @@
#include "tensorflow/core/kernels/cwise_ops.h"
namespace tensorflow {
+
+enum QuantizerRoundMode {
+ // Round half up: if the fraction of y is exactly 0.5, then
+ // round(y) = y + 0.5
+ // E.g., -5.5 gets rounded to -5, -5.4 goes to -5,
+ // 5.4 goes to 5, and 5.5 goes to 6.
+ ROUND_HALF_UP,
+ // Round half to even: if the fraction of y is exactly 0.5, then round(y) is
+ // the nearest even integer to y.
+ // E.g., 23.5 gets rounded to 24, 24.5 gets rounded to 24, while -23.5 becomes
+ // -24, and -24.5 gets rounded to 24.
+ ROUND_HALF_TO_EVEN,
+};
+
namespace functor {
// TODO(pauldonnelly): 'signed_input' should really be called 'signed_output'.
@@ -31,15 +45,69 @@
void operator()(const Device& d, typename TTypes<T>::ConstVec input,
bool signed_input, int num_bits, bool range_given,
Tensor* input_min_tensor, Tensor* input_max_tensor,
- typename TTypes<T>::Vec out);
+ QuantizerRoundMode round_mode, typename TTypes<T>::Vec out);
};
// The implementation below runs on both CPU and GPU.
+template <typename Device, typename T, typename Func>
+void ClampScaleAndRound(const Device& d, typename TTypes<T>::ConstVec input,
+ T min_range, T max_range, T scale, T inverse_scale,
+ Func round_func, typename TTypes<T>::Vec out) {
+ out.device(d) = (input.cwiseMin(max_range).cwiseMax(min_range) * scale)
+ .unaryExpr(round_func) *
+ inverse_scale;
+}
+
+// The implementation below runs on both CPU and GPU.
+template <typename Device, typename T>
+void ClampScaleAndRound(const Device& d, typename TTypes<T>::ConstVec input,
+ T min_range, T max_range, T scale, T inverse_scale,
+ QuantizerRoundMode round_mode,
+ typename TTypes<T>::Vec out) {
+ switch (round_mode) {
+ case ROUND_HALF_TO_EVEN:
+ ClampScaleAndRound(d, input, min_range, max_range, scale, inverse_scale,
+ Eigen::internal::scalar_round_op_google<T>(), out);
+ break;
+ case ROUND_HALF_UP:
+ ClampScaleAndRound(d, input, min_range, max_range, scale, inverse_scale,
+ Eigen::internal::scalar_round_up_op<T>(), out);
+ break;
+ }
+}
+
+// The implementation below runs on both CPU and GPU.
+template <typename Device, typename T, typename Func>
+void ScaleAndRound(const Device& d, typename TTypes<T>::ConstVec input, T scale,
+ T inverse_scale, Func round_func,
+ typename TTypes<T>::Vec out) {
+ out.device(d) = (input * scale).unaryExpr(round_func) * inverse_scale;
+}
+
+// The implementation below runs on both CPU and GPU.
+template <typename Device, typename T>
+void ScaleAndRound(const Device& d, typename TTypes<T>::ConstVec input, T scale,
+ T inverse_scale, QuantizerRoundMode round_mode,
+ typename TTypes<T>::Vec out) {
+ switch (round_mode) {
+ case ROUND_HALF_TO_EVEN:
+ ScaleAndRound(d, input, scale, inverse_scale,
+ Eigen::internal::scalar_round_op_google<T>(), out);
+ break;
+ case ROUND_HALF_UP:
+ ScaleAndRound(d, input, scale, inverse_scale,
+ Eigen::internal::scalar_round_up_op<T>(), out);
+ break;
+ }
+}
+
+// The implementation below runs on both CPU and GPU.
template <typename Device, typename T>
struct QuantizeAndDequantizeOneScaleImpl {
static void Compute(const Device& d, typename TTypes<T>::ConstVec input,
bool signed_input, int num_bits, bool range_given,
Tensor* input_min_tensor, Tensor* input_max_tensor,
+ QuantizerRoundMode round_mode,
typename TTypes<T>::Vec out) {
T min_range;
T max_range;
@@ -89,15 +157,10 @@
// The semantics of the op does not guarantee to clamp to the specified
// min_range and max_range - because we may have changed either min_range
// or max_range.
- out.device(d) =
- (input.cwiseMin(max_range).cwiseMax(min_range) * scale)
- .unaryExpr(Eigen::internal::scalar_round_op_google<T>()) *
- inverse_scale;
+ ClampScaleAndRound(d, input, min_range, max_range, scale, inverse_scale,
+ round_mode, out);
} else {
- out.device(d) =
- (input * scale)
- .unaryExpr(Eigen::internal::scalar_round_op_google<T>()) *
- inverse_scale;
+ ScaleAndRound(d, input, scale, inverse_scale, round_mode, out);
}
}
};
diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc b/tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc
index 61c79cf..5745e41 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op_gpu.cu.cc
@@ -32,10 +32,10 @@
void operator()(const GPUDevice& d, typename TTypes<T>::ConstVec input,
bool signed_input, int num_bits, bool range_given,
Tensor* input_min_tensor, Tensor* input_max_tensor,
- typename TTypes<T>::Vec out) {
+ QuantizerRoundMode round_mode, typename TTypes<T>::Vec out) {
QuantizeAndDequantizeOneScaleImpl<GPUDevice, T>::Compute(
d, input, signed_input, num_bits, range_given, input_min_tensor,
- input_max_tensor, out);
+ input_max_tensor, round_mode, out);
}
};
} // end namespace functor
diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
index cddabf8..b9e015c 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
@@ -101,17 +101,51 @@
.Attr("range_given", false)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
- AddInputFromArray<float>(TensorShape({6}), {-1, -0.5, 0, 0.3, 0.8, 0.555});
+ AddInputFromArray<float>(TensorShape({7}),
+ {-1, -0.5, 0, 0.3, 0.8, 0.555, 0.50390625});
AddInputFromArray<float>(TensorShape({}), {0.0}); // Min
AddInputFromArray<float>(TensorShape({}), {0.0}); // Max
- // With int8, the tensor is quantized to {-128, -64, 0, 38, 102, 71}.
+ // With int8, the tensor is quantized to {-128, -64, 0, 38, 102, 71, 64}.
// Scale is: 1/127
- // Then it is dequantized to {-1, -0.5, 0, 38.0/128, 102.0/128, 71.0/128}
+ // Then it is dequantized to {-1, -0.5, 0, 38.0/128, 102.0/128, 71.0/128, 0.5}
TF_ASSERT_OK(RunOpKernel());
- Tensor expected(allocator(), DT_FLOAT, TensorShape({6}));
- test::FillValues<float>(&expected,
- {-1, -0.5, 0, 38.0 / 128, 102.0 / 128, 71.0 / 128});
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({7}));
+ test::FillValues<float>(
+ &expected, {-1, -0.5, 0, 38.0 / 128, 102.0 / 128, 71.0 / 128, 0.5});
+ test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
+
+ // Ensure that the inputs haven't been changed.
+ EXPECT_EQ(inputs_[1]->scalar<float>()(), 0.0);
+ EXPECT_EQ(inputs_[2]->scalar<float>()(), 0.0);
+}
+
+// Convert a 1D tensor with signed 8 bits and round_mode half_up.
+TEST_F(QuantizeAndDequantizeTest, Convert_1D_tensor_with_int8_round_half_up) {
+ TF_ASSERT_OK(
+ NodeDefBuilder("quantize_and_dequantize_op", "QuantizeAndDequantizeV2")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("signed_input", true)
+ .Attr("num_bits", 8)
+ .Attr("range_given", false)
+ .Attr("round_mode", "HALF_UP")
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ AddInputFromArray<float>(TensorShape({7}),
+ {-1, -0.5, 0, 0.3, 0.8, 0.555, 0.50390625});
+ AddInputFromArray<float>(TensorShape({}), {0.0}); // Min
+ AddInputFromArray<float>(TensorShape({}), {0.0}); // Max
+
+ // With int8, the tensor is quantized to {-128, -64, 0, 38, 102, 71, 65}.
+ // Scale is: 1/127
+ // Then it is dequantized to {-1, -0.5, 0, 38.0/128, 102.0/128, 71.0/128,
+ // 65.0 /128}
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({7}));
+ test::FillValues<float>(&expected, {-1, -0.5, 0, 38.0 / 128, 102.0 / 128,
+ 71.0 / 128, 65.0 / 128});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
// Ensure that the inputs haven't been changed.
@@ -162,7 +196,7 @@
.Attr("range_given", false)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
- AddInputFromArray<float>(TensorShape({6}), {-1, -0.5, 0, 0.3, 0.8, 0.555});
+ AddInputFromArray<float>(TensorShape({6}), {-1, -0.5, 0, 0.3125, 0.8, 0.555});
AddInputFromArray<float>(TensorShape({}), {0.0}); // Min
AddInputFromArray<float>(TensorShape({}), {0.0}); // Max
@@ -178,6 +212,35 @@
EXPECT_EQ(inputs_[2]->scalar<float>()(), 0.0);
}
+// Convert a 1D tensor with signed 4 bits and round_mode hafl_up.
+TEST_F(QuantizeAndDequantizeTest, Convert_1D_tensor_with_int4_round_half_up) {
+ TF_ASSERT_OK(
+ NodeDefBuilder("quantize_and_dequantize_op", "QuantizeAndDequantizeV2")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("signed_input", true)
+ .Attr("num_bits", 4)
+ .Attr("range_given", false)
+ .Attr("round_mode", "HALF_UP")
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ AddInputFromArray<float>(TensorShape({6}), {-1, -0.5, 0, 0.3125, 0.8, 0.555});
+ AddInputFromArray<float>(TensorShape({}), {0.0}); // Min
+ AddInputFromArray<float>(TensorShape({}), {0.0}); // Max
+
+ // With int4, the tensor is quantized to {-8, -4, 0, 3, 6, 4}.
+ // Scale is: 1/8
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({6}));
+ test::FillValues<float>(&expected, {-1, -0.5, 0, 0.375, 0.75, 0.5});
+ test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
+
+ // Ensure that the inputs haven't been changed.
+ EXPECT_EQ(inputs_[1]->scalar<float>()(), 0.0);
+ EXPECT_EQ(inputs_[2]->scalar<float>()(), 0.0);
+}
+
// Convert a 1D tensor with signed 4 bits.
TEST_F(QuantizeAndDequantizeTest, Convert_1D_tensor_with_int4_V3) {
TF_ASSERT_OK(
@@ -237,6 +300,38 @@
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
}
+// Convert a 2D tensor with signed 8 bits, given range and round_mode half_up.
+TEST_F(QuantizeAndDequantizeTest,
+ Convert_2D_tensor_with_int8_range_given_round_half_up) {
+ TF_ASSERT_OK(
+ NodeDefBuilder("quantize_and_dequantize_op", "QuantizeAndDequantizeV2")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("signed_input", true)
+ .Attr("num_bits", 8)
+ .Attr("range_given", true)
+ .Attr("round_mode", "HALF_UP")
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ // Note that the last two values are saturated.
+ AddInputFromArray<float>(TensorShape({2, 4}),
+ {-0.8, -0.5, 0, 0.3, 0.8, 0.555, -2, 33});
+ AddInputFromArray<float>(TensorShape({}), {-1.0}); // Min
+ AddInputFromArray<float>(TensorShape({}), {1.0}); // Max
+
+ // Note that the range is given as [-1, 1].
+ // With int8, the tensor is quantized to {-102, -63, 0, 38, 102, 70, -128,
+ // 127}.
+ // Scale is: 1/127
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 4}));
+ test::FillValues<float>(
+ &expected, {-102.0 / 127, -63.0 / 127, 0, 38.0 / 127, 102.0 / 127,
+ 70.0 / 127, -128.0 / 127, 1});
+ test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
+}
+
// Convert a 2D tensor with signed 8 bits with given range.
TEST_F(QuantizeAndDequantizeTest, Convert_2D_tensor_with_int8_range_given_V3) {
TF_ASSERT_OK(
@@ -293,6 +388,33 @@
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
}
+// Convert a 4D tensor with unsigned 8 bits, given range and round_mode half_up.
+TEST_F(QuantizeAndDequantizeTest,
+ Convert_4D_tensor_with_uint8_range_given_round_half_up) {
+ TF_ASSERT_OK(
+ NodeDefBuilder("quantize_and_dequantize_op", "QuantizeAndDequantizeV2")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("signed_input", false)
+ .Attr("num_bits", 8)
+ .Attr("range_given", true)
+ .Attr("round_mode", "HALF_UP")
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ AddInputFromArray<float>(TensorShape({2, 2, 1, 1}), {-0.5, 0, 0.3, 0.8});
+ AddInputFromArray<float>(TensorShape({}), {0.0}); // Min
+ AddInputFromArray<float>(TensorShape({}), {1.0}); // Max
+
+ // Note that the range is given as [0, 1].
+ // With int8, the tensor is quantized to {0, 0, 77, 204}
+ // Scale is: 1/255
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 1, 1}));
+ test::FillValues<float>(&expected, {0, 0, 77.0 / 255, 204.0 / 255});
+ test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
+}
+
// Convert a 4D tensor with unsigned 8 bits with given range.
TEST_F(QuantizeAndDequantizeTest, Convert_4D_tensor_with_uint8_range_given_V3) {
TF_ASSERT_OK(
diff --git a/tensorflow/core/kernels/quantized_resize_bilinear_op_test.cc b/tensorflow/core/kernels/quantized_resize_bilinear_op_test.cc
index e613341..6fc4894 100644
--- a/tensorflow/core/kernels/quantized_resize_bilinear_op_test.cc
+++ b/tensorflow/core/kernels/quantized_resize_bilinear_op_test.cc
@@ -273,7 +273,7 @@
<< expected_val << ", " << resized_image_val;
}
- // Value testing with reference implemenatation
+ // Value testing with reference implementation
CheckTensorValue<qint32>(image_quantized_tensor.flat<qint32>().data(),
outputs.at(0).flat<qint32>().data(),
/*batch_size=*/1,
diff --git a/tensorflow/core/kernels/stage_op.cc b/tensorflow/core/kernels/stage_op.cc
index 73a02a3..c91bdc4 100644
--- a/tensorflow/core/kernels/stage_op.cc
+++ b/tensorflow/core/kernels/stage_op.cc
@@ -151,7 +151,7 @@
}
// Are there a limit number of elements or a memory limit
- // configued on this buffer?
+ // configured on this buffer?
bool IsBounded() const { return capacity_ > 0 || memory_limit_ > 0; }
bool IsCapacityFull() const { return buf_.size() >= capacity_; }
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index a97a71b..aa85f54 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -352,9 +352,9 @@
}
const auto key = strings::StrCat(output_handle(0), output_handle(1));
- auto creator = [this, key, tensor_array, array_size, marked_size,
- element_shape, shape_to_prepend, tensor_array_output_handle,
- output_handle](TensorArray** ret) -> Status {
+ auto creator = [key, tensor_array, array_size, marked_size, element_shape,
+ shape_to_prepend,
+ tensor_array_output_handle](TensorArray** ret) -> Status {
*ret = new TensorArray(
key, tensor_array->ElemType(), *tensor_array_output_handle,
array_size, element_shape, tensor_array->HasIdenticalElementShapes(),
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index f55562e..e07a35a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -2743,6 +2743,9 @@
.Attr("range_given: bool = false")
.Output("output: T")
.Attr("T: {bfloat16, half, float, double}")
+ .Attr(
+ "round_mode: {'HALF_TO_EVEN', 'HALF_UP'} = "
+ "'HALF_TO_EVEN'")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index dd1aaf9..ba0bf55 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -21942,6 +21942,18 @@
}
}
op {
+ name: "ExperimentalMatchingFilesDataset"
+ input_arg {
+ name: "patterns"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
name: "ExperimentalMaterializedIndexDatasetHandle"
output_arg {
name: "handle"
@@ -21970,6 +21982,33 @@
is_stateful: true
}
op {
+ name: "ExperimentalMaxIntraOpParallelismDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "max_intra_op_parallelism"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ExperimentalNonSerializableDataset"
input_arg {
name: "input_dataset"
@@ -22041,6 +22080,33 @@
}
}
op {
+ name: "ExperimentalPrivateThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "num_threads"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ExperimentalSleepDataset"
input_arg {
name: "input_dataset"
@@ -31491,18 +31557,6 @@
}
}
op {
- name: "MatchingFilesDataset"
- input_arg {
- name: "patterns"
- type: DT_STRING
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- is_stateful: true
-}
-op {
name: "MatrixBandPart"
input_arg {
name: "input"
@@ -41733,6 +41787,71 @@
}
}
op {
+ name: "QuantizeAndDequantizeV2"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_min"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_max"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "signed_input"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ attr {
+ name: "num_bits"
+ type: "int"
+ default_value {
+ i: 8
+ }
+ }
+ attr {
+ name: "range_given"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "round_mode"
+ type: "string"
+ default_value {
+ s: "HALF_TO_EVEN"
+ }
+ allowed_values {
+ list {
+ s: "HALF_TO_EVEN"
+ s: "HALF_UP"
+ }
+ }
+ }
+}
+op {
name: "QuantizeAndDequantizeV3"
input_arg {
name: "input"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 8402f25..e7212b7 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -622,18 +622,6 @@
return shape_inference::ScalarShape(c);
});
-REGISTER_OP("MatchingFilesDataset")
- .Input("patterns: string")
- .Output("handle: variant")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
- // stateful to inhibit constant folding.
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- // `patterns` must be a scalar or a vector.
- TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
- return shape_inference::ScalarShape(c);
- });
-
REGISTER_OP("SqlDataset")
.Input("driver_name: string")
.Input("data_source_name: string")
diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc
index 9733cf2..aebe0bf 100644
--- a/tensorflow/core/ops/experimental_dataset_ops.cc
+++ b/tensorflow/core/ops/experimental_dataset_ops.cc
@@ -86,6 +86,18 @@
.Attr("use_inter_op_parallelism: bool = true")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ExperimentalMatchingFilesDataset")
+ .Input("patterns: string")
+ .Output("handle: variant")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // `patterns` must be a scalar or a vector.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+ return shape_inference::ScalarShape(c);
+ });
+
REGISTER_OP("ExperimentalNonSerializableDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
@@ -140,6 +152,22 @@
.Input("function_buffer_resource: resource")
.SetShapeFn(shape_inference::UnknownShape);
+REGISTER_OP("ExperimentalMaxIntraOpParallelismDataset")
+ .Input("input_dataset: variant")
+ .Input("max_intra_op_parallelism: int64")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalPrivateThreadPoolDataset")
+ .Input("input_dataset: variant")
+ .Input("num_threads: int64")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("ExperimentalThreadPoolDataset")
.Input("input_dataset: variant")
.Input("thread_pool: resource")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index bc35ce7..bae50a7 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -10405,6 +10405,18 @@
}
}
op {
+ name: "ExperimentalMatchingFilesDataset"
+ input_arg {
+ name: "patterns"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
name: "ExperimentalMaterializedIndexDatasetHandle"
output_arg {
name: "handle"
@@ -10433,6 +10445,33 @@
is_stateful: true
}
op {
+ name: "ExperimentalMaxIntraOpParallelismDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "max_intra_op_parallelism"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ExperimentalNonSerializableDataset"
input_arg {
name: "input_dataset"
@@ -10504,6 +10543,33 @@
}
}
op {
+ name: "ExperimentalPrivateThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "num_threads"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ExperimentalSleepDataset"
input_arg {
name: "input_dataset"
@@ -15938,18 +16004,6 @@
}
}
op {
- name: "MatchingFilesDataset"
- input_arg {
- name: "patterns"
- type: DT_STRING
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- is_stateful: true
-}
-op {
name: "MatrixBandPart"
input_arg {
name: "input"
@@ -21036,6 +21090,19 @@
}
}
}
+ attr {
+ name: "round_mode"
+ type: "string"
+ default_value {
+ s: "HALF_TO_EVEN"
+ }
+ allowed_values {
+ list {
+ s: "HALF_TO_EVEN"
+ s: "HALF_UP"
+ }
+ }
+ }
}
op {
name: "QuantizeAndDequantizeV3"
diff --git a/tensorflow/core/platform/cpu_feature_guard.cc b/tensorflow/core/platform/cpu_feature_guard.cc
index 9d00aa7..2efe0c0 100644
--- a/tensorflow/core/platform/cpu_feature_guard.cc
+++ b/tensorflow/core/platform/cpu_feature_guard.cc
@@ -41,7 +41,7 @@
}
}
-// Check if CPU feature is inclued in the TensorFlow binary.
+// Check if CPU feature is included in the TensorFlow binary.
void CheckIfFeatureUnused(CPUFeature feature, const string& feature_name,
string& missing_instructions) {
if (TestCPUFeature(feature)) {
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 3a4415f..0428715 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -543,6 +543,9 @@
def tf_additional_human_readable_json_deps():
return []
+def tf_additional_logger_deps():
+ return []
+
def tf_additional_all_protos():
return ["//tensorflow/core:protos_all"]
diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc
index cf8b477..8351362 100644
--- a/tensorflow/core/platform/default/device_tracer.cc
+++ b/tensorflow/core/platform/default/device_tracer.cc
@@ -297,19 +297,16 @@
// for the duration of the CUPTI API callback.
TF_STATIC_THREAD_LOCAL_POD(const char *, tls_current_annotation);
-class DeviceTracerImpl : public DeviceTracer,
- public CUPTIClient,
- public tracing::TraceCollector {
+class TraceCollectorImpl : public tracing::TraceCollector {
public:
- DeviceTracerImpl(CUPTIManager *cupti_manager);
- ~DeviceTracerImpl() override;
+ TraceCollectorImpl() { tracing::SetTraceCollector(this); }
- // DeviceTracer interface:
- Status Start() override;
- Status Stop() override;
- Status Collect(StepStatsCollector *collector) override;
+ ~TraceCollectorImpl() override {
+ DCHECK(!active_trace_session_)
+ << "Unexpected active trace session detected. ";
+ }
- // tracing::TraceCollector interface:
+ // Note the method can be called after a call to Stop().
virtual std::unique_ptr<Handle> CreateAnnotationHandle(
StringPiece name_part1, StringPiece name_part2) const {
struct Impl : public tracing::TraceCollector::Handle {
@@ -332,8 +329,7 @@
}
bool IsEnabledForAnnotations() const override {
- // We are always enabled for 'Annotations'.
- return true;
+ return active_trace_session_.load(std::memory_order_relaxed);
}
bool IsEnabledForActivities(bool is_expensive) const override {
@@ -341,6 +337,36 @@
return false;
}
+ void Start() {
+ DCHECK(!active_trace_session_)
+ << "Unexpected active trace session detected. ";
+ active_trace_session_ = true;
+ }
+
+ void Stop() {
+ DCHECK(active_trace_session_) << "No active trace session detected. ";
+ active_trace_session_ = false;
+ }
+
+ private:
+ std::atomic<bool> active_trace_session_;
+};
+
+TraceCollectorImpl *GlobalDefaultTraceCollector() {
+ static auto *instance = new TraceCollectorImpl();
+ return instance;
+}
+
+class DeviceTracerImpl : public DeviceTracer, public CUPTIClient {
+ public:
+ DeviceTracerImpl(CUPTIManager *cupti_manager);
+ ~DeviceTracerImpl() override;
+
+ // DeviceTracer interface:
+ Status Start() override;
+ Status Stop() override;
+ Status Collect(StepStatsCollector *collector) override;
+
protected:
// This callback is used exclusively by CUPTIManager.
friend class CUPTIManager;
@@ -430,7 +456,7 @@
}
// Register as a TraceEngine to receive ScopedAnnotations.
- tracing::SetTraceCollector(this);
+ GlobalDefaultTraceCollector()->Start();
// Intercept launch and memcpy calls to capture the Op name annotation.
// TODO(pbar) Add callbacks for memcpy variants.
@@ -478,7 +504,8 @@
return Status::OK();
}
CUPTI_CALL(Unsubscribe(subscriber_));
- tracing::SetTraceCollector(nullptr);
+ GlobalDefaultTraceCollector()->Stop();
+
TF_RETURN_IF_ERROR(cupti_manager_->DisableTrace());
end_walltime_us_ = NowInUsec();
CUPTI_CALL(GetTimestamp(&end_timestamp_));
diff --git a/tensorflow/core/platform/default/logger.cc b/tensorflow/core/platform/default/logger.cc
new file mode 100644
index 0000000..54b1a1a
--- /dev/null
+++ b/tensorflow/core/platform/default/logger.cc
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/logger.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+Logger* Logger::Singleton() {
+ class DefaultLogger : public Logger {
+ private:
+ void DoLogProto(google::protobuf::Any* proto) override {
+ VLOG(2) << proto->ShortDebugString();
+ }
+ void DoFlush() override {}
+ };
+ static Logger* instance = new DefaultLogger();
+ return instance;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/default/logging.cc b/tensorflow/core/platform/default/logging.cc
index 133ae45..26bd854 100644
--- a/tensorflow/core/platform/default/logging.cc
+++ b/tensorflow/core/platform/default/logging.cc
@@ -21,18 +21,18 @@
#include <android/log.h>
#include <iostream>
#include <sstream>
-#include <cstring>
#endif
#include <stdlib.h>
+#include <string.h>
#include <time.h>
+#include <string>
+#include <unordered_map>
+
namespace tensorflow {
namespace internal {
-LogMessage::LogMessage(const char* fname, int line, int severity)
- : fname_(fname), line_(line), severity_(severity) {}
-
#if defined(PLATFORM_POSIX_ANDROID)
void LogMessage::GenerateLogMessage() {
int android_log_level;
@@ -94,24 +94,90 @@
namespace {
+int ParseInteger(const char* str, size_t size) {
+ // Ideally we would use env_var / safe_strto64, but it is
+ // hard to use here without pulling in a lot of dependencies,
+ // so we use std:istringstream instead
+ string integer_str(str, size);
+ std::istringstream ss(integer_str);
+ int level = 0;
+ ss >> level;
+ return level;
+}
+
// Parse log level (int64) from environment variable (char*)
int64 LogLevelStrToInt(const char* tf_env_var_val) {
if (tf_env_var_val == nullptr) {
return 0;
}
+ return ParseInteger(tf_env_var_val, strlen(tf_env_var_val));
+}
- // Ideally we would use env_var / safe_strto64, but it is
- // hard to use here without pulling in a lot of dependencies,
- // so we use std:istringstream instead
- string min_log_level(tf_env_var_val);
- std::istringstream ss(min_log_level);
- int64 level;
- if (!(ss >> level)) {
- // Invalid vlog level setting, set level to default (0)
- level = 0;
+// Using StringPiece breaks Windows build.
+struct StringData {
+ struct Hasher {
+ size_t operator()(const StringData& sdata) const {
+ // For dependency reasons, we cannot use hash.h here. Use DBJHash instead.
+ size_t hash = 5381;
+ const char* data = sdata.data;
+ for (const char* top = data + sdata.size; data < top; ++data) {
+ hash = ((hash << 5) + hash) + (*data);
+ }
+ return hash;
+ }
+ };
+
+ StringData() = default;
+ StringData(const char* data, size_t size) : data(data), size(size) {}
+
+ bool operator==(const StringData& rhs) const {
+ return size == rhs.size && memcmp(data, rhs.data, size) == 0;
}
- return level;
+ const char* data = nullptr;
+ size_t size = 0;
+};
+
+using VmoduleMap = std::unordered_map<StringData, int, StringData::Hasher>;
+
+// Returns a mapping from module name to VLOG level, derived from the
+// TF_CPP_VMOUDLE environment variable; ownership is transferred to the caller.
+VmoduleMap* VmodulesMapFromEnv() {
+ // The value of the env var is supposed to be of the form:
+ // "foo=1,bar=2,baz=3"
+ const char* env = getenv("TF_CPP_VMODULE");
+ if (env == nullptr) {
+ // If there is no TF_CPP_VMODULE configuration (most common case), return
+ // nullptr so that the ShouldVlogModule() API can fast bail out of it.
+ return nullptr;
+ }
+ // The memory returned by getenv() can be invalidated by following getenv() or
+ // setenv() calls. And since we keep references to it in the VmoduleMap in
+ // form of StringData objects, make a copy of it.
+ const char* env_data = strdup(env);
+ VmoduleMap* result = new VmoduleMap();
+ while (true) {
+ const char* eq = strchr(env_data, '=');
+ if (eq == nullptr) {
+ break;
+ }
+ const char* after_eq = eq + 1;
+
+ // Comma either points at the next comma delimiter, or at a null terminator.
+ // We check that the integer we parse ends at this delimiter.
+ const char* comma = strchr(after_eq, ',');
+ const char* new_env_data;
+ if (comma == nullptr) {
+ comma = strchr(after_eq, '\0');
+ new_env_data = comma;
+ } else {
+ new_env_data = comma + 1;
+ }
+ (*result)[StringData(env_data, eq - env_data)] =
+ ParseInteger(after_eq, comma - after_eq);
+ env_data = new_env_data;
+ }
+ return result;
}
} // namespace
@@ -146,10 +212,15 @@
#endif
}
+LogMessage::LogMessage(const char* fname, int line, int severity)
+ : fname_(fname), line_(line), severity_(severity) {}
+
LogMessage::~LogMessage() {
// Read the min log level once during the first call to logging.
static int64 min_log_level = MinLogLevelFromEnv();
- if (TF_PREDICT_TRUE(severity_ >= min_log_level)) GenerateLogMessage();
+ if (severity_ >= min_log_level) {
+ GenerateLogMessage();
+ }
}
int64 LogMessage::MinVLogLevel() {
@@ -157,6 +228,24 @@
return min_vlog_level;
}
+bool LogMessage::VmoduleActivated(const char* fname, int level) {
+ if (level <= MinVLogLevel()) {
+ return true;
+ }
+ static VmoduleMap* vmodules = VmodulesMapFromEnv();
+ if (TF_PREDICT_TRUE(vmodules == nullptr)) {
+ return false;
+ }
+ const char* last_slash = strrchr(fname, '/');
+ const char* module_start = last_slash == nullptr ? fname : last_slash + 1;
+ const char* dot_after = strchr(module_start, '.');
+ const char* module_limit =
+ dot_after == nullptr ? strchr(fname, '\0') : dot_after;
+ StringData module(module_start, module_limit - module_start);
+ auto it = vmodules->find(module);
+ return it != vmodules->end() && it->second >= level;
+}
+
LogMessageFatal::LogMessageFatal(const char* file, int line)
: LogMessage(file, line, FATAL) {}
LogMessageFatal::~LogMessageFatal() {
diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h
index 08a692ff..bb8735e 100644
--- a/tensorflow/core/platform/default/logging.h
+++ b/tensorflow/core/platform/default/logging.h
@@ -46,6 +46,17 @@
// but VLOG(3) will not. Defaults to 0.
static int64 MinVLogLevel();
+ // Returns whether VLOG level lvl is activated for the file fname.
+ //
+ // E.g. if the environment variable TF_CPP_VMODULE contains foo=3 and fname is
+ // foo.cc and lvl is <= 3, this will return true. It will also return true if
+ // the level is lower or equal to TF_CPP_MIN_VLOG_LEVEL (default zero).
+ //
+ // It is expected that the result of this query will be cached in the VLOG-ing
+ // call site to avoid repeated lookups. This routine performs a hash-map
+ // access against the VLOG-ing specification provided by the env var.
+ static bool VmoduleActivated(const char* fname, int level);
+
protected:
void GenerateLogMessage();
@@ -55,6 +66,13 @@
int severity_;
};
+// Uses the lower operator & precedence to voidify a LogMessage reference, so
+// that the ternary VLOG() implementation is balanced, type wise.
+struct Voidifier {
+ template <typename T>
+ void operator&(const T&)const {}
+};
+
// LogMessageFatal ensures the process will exit in failure after
// logging this message.
class LogMessageFatal : public LogMessage {
@@ -77,18 +95,30 @@
#define LOG(severity) _TF_LOG_##severity
#ifdef IS_MOBILE_PLATFORM
+
// Turn VLOG off when under mobile devices for considerations of binary size.
#define VLOG_IS_ON(lvl) ((lvl) <= 0)
+
#else
-// Otherwise, Set TF_CPP_MIN_VLOG_LEVEL environment to update minimum log level
-// of VLOG
-#define VLOG_IS_ON(lvl) \
- ((lvl) <= ::tensorflow::internal::LogMessage::MinVLogLevel())
+
+// Otherwise, set TF_CPP_MIN_VLOG_LEVEL environment to update minimum log level
+// of VLOG, or TF_CPP_VMODULE to set the minimum log level for individual
+// translation units.
+#define VLOG_IS_ON(lvl) \
+ (([](int level, const char* fname) { \
+ static const bool vmodule_activated = \
+ ::tensorflow::internal::LogMessage::VmoduleActivated(fname, level); \
+ return vmodule_activated; \
+ })(lvl, __FILE__))
+
#endif
-#define VLOG(lvl) \
- if (TF_PREDICT_FALSE(VLOG_IS_ON(lvl))) \
- ::tensorflow::internal::LogMessage(__FILE__, __LINE__, tensorflow::INFO)
+#define VLOG(level) \
+ TF_PREDICT_TRUE(!VLOG_IS_ON(level)) \
+ ? (void)0 \
+ : ::tensorflow::internal::Voidifier() & \
+ ::tensorflow::internal::LogMessage(__FILE__, __LINE__, \
+ tensorflow::INFO)
// CHECK dies with a fatal error if condition is not true. It is *not*
// controlled by NDEBUG, so the check will be executed regardless of
diff --git a/tensorflow/core/platform/logger.h b/tensorflow/core/platform/logger.h
new file mode 100644
index 0000000..5d304be
--- /dev/null
+++ b/tensorflow/core/platform/logger.h
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_LOGGER_H_
+#define TENSORFLOW_CORE_PLATFORM_LOGGER_H_
+
+#include "google/protobuf/any.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// Abstract logging interface. Contrary to logging.h, this class describes an
+// interface, not a concrete logging mechanism. This is useful when we want to
+// log anything to a non-local place, e.g. a database.
+class Logger {
+ public:
+ static Logger* Singleton();
+
+ virtual ~Logger() = default;
+
+ // Logs a typed proto.
+ template <typename ProtoType>
+ void LogProto(const ProtoType& proto) {
+ google::protobuf::Any any;
+ any.PackFrom(proto);
+ DoLogProto(&any);
+ }
+
+ // Flushes any pending log. Blocks until everything is flushed.
+ void Flush() { DoFlush(); }
+
+ private:
+ virtual void DoLogProto(google::protobuf::Any* proto) = 0;
+ virtual void DoFlush() = 0;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PLATFORM_LOGGER_H_
diff --git a/tensorflow/core/platform/numa_test.cc b/tensorflow/core/platform/numa_test.cc
index 8b39ecd..91789ef 100644
--- a/tensorflow/core/platform/numa_test.cc
+++ b/tensorflow/core/platform/numa_test.cc
@@ -44,7 +44,7 @@
TEST(Numa, SetNodeAffinity) {
// NOTE(tucker): This test is not reliable when executed under tap because
- // the virtual machine may not have access to all of the availble NUMA
+ // the virtual machine may not have access to all of the available NUMA
// nodes. Not sure what to do about that.
EXPECT_EQ(-1, port::NUMAGetThreadNodeAffinity());
if (port::NUMAEnabled()) {
diff --git a/tensorflow/core/platform/posix/posix_file_system.cc b/tensorflow/core/platform/posix/posix_file_system.cc
index c7afab9..fc48cab 100644
--- a/tensorflow/core/platform/posix/posix_file_system.cc
+++ b/tensorflow/core/platform/posix/posix_file_system.cc
@@ -240,11 +240,14 @@
}
Status PosixFileSystem::CreateDir(const string& name) {
- Status result;
- if (mkdir(TranslateName(name).c_str(), 0755) != 0) {
- result = IOError(name, errno);
+ string translated = TranslateName(name);
+ if (translated.empty()) {
+ return errors::AlreadyExists(name);
}
- return result;
+ if (mkdir(translated.c_str(), 0755) != 0) {
+ return IOError(name, errno);
+ }
+ return Status::OK();
}
Status PosixFileSystem::DeleteDir(const string& name) {
diff --git a/tensorflow/core/platform/regexp.h b/tensorflow/core/platform/regexp.h
index a4eedf3..ca9ca1e 100644
--- a/tensorflow/core/platform/regexp.h
+++ b/tensorflow/core/platform/regexp.h
@@ -16,6 +16,7 @@
#ifndef TENSORFLOW_PLATFORM_REGEXP_H_
#define TENSORFLOW_PLATFORM_REGEXP_H_
+#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/types.h"
@@ -23,7 +24,7 @@
defined(GOOGLE_RE2)
#include "tensorflow/core/platform/google/build_config/re2.h"
namespace tensorflow {
-typedef ::StringPiece RegexpStringPiece;
+typedef absl::string_view RegexpStringPiece;
} // namespace tensorflow
#else
diff --git a/tensorflow/core/profiler/internal/tfprof_code.cc b/tensorflow/core/profiler/internal/tfprof_code.cc
index 744e1e9..0c26855 100644
--- a/tensorflow/core/profiler/internal/tfprof_code.cc
+++ b/tensorflow/core/profiler/internal/tfprof_code.cc
@@ -183,7 +183,7 @@
// This method adds the statistics of graph nodes created by the python
// call.
void Add(const CodeNode* node, const std::vector<uint64>& location_ids) {
- // displayed leaf might not be true leaf. Retrive the true leaves for
+ // displayed leaf might not be true leaf. Retrieve the true leaves for
// stats.
std::vector<const CodeNode*> all_leaf = FetchAllLeaf(node);
CHECK(!all_leaf.empty()) << node->name();
diff --git a/tensorflow/core/profiler/internal/tfprof_node.cc b/tensorflow/core/profiler/internal/tfprof_node.cc
index 86cb20d..8796234 100644
--- a/tensorflow/core/profiler/internal/tfprof_node.cc
+++ b/tensorflow/core/profiler/internal/tfprof_node.cc
@@ -151,7 +151,7 @@
}
// TODO(xpan): Make this more accurate:
- // High level: Memory tracking is suspicous and requires large scale
+ // High level: Memory tracking is suspicious and requires large scale
// clean up.
// Investigte the memory usage difference between CPU/GPU with OpViewTest.
//
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 174b588..b3dc5dc 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -291,6 +291,13 @@
// transport for client-master communication that avoids the RPC
// stack. This option is primarily for used testing the RPC stack.
bool use_rpc_for_inprocess_master = 1;
+
+ // The compression algorithm to be used. One of "deflate", "gzip".
+ string compression_algorithm = 2;
+
+ // If compression_algorithm is set, the compression level to be used.
+ // From 0 (no compression), up to 3.
+ int32 compression_level = 3;
};
// Session configuration parameters.
diff --git a/tensorflow/core/protobuf/master.proto b/tensorflow/core/protobuf/master.proto
index 0302287..c104463 100644
--- a/tensorflow/core/protobuf/master.proto
+++ b/tensorflow/core/protobuf/master.proto
@@ -224,7 +224,7 @@
message ResetRequest {
// A list of container names, which may be empty.
//
- // If 'container' is not empty, releases resoures in the given
+ // If 'container' is not empty, releases resources in the given
// containers in all devices.
//
// If 'container' is empty, releases resources in the default
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index d68f273..515d673 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -38,7 +38,7 @@
}
// Enum controlling the number of times to run optimizers. The default is to
- // run them once.
+ // run them twice.
enum NumIterationsType {
DEFAULT_NUM_ITERS = 0;
ONE = 1;
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index b9ca8ab..89c163a 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -238,15 +238,6 @@
static Status Split(const SparseTensor& tensor, const int split_dim,
const int num_split, std::vector<SparseTensor>* result);
- template <typename T>
- ABSL_DEPRECATED(
- "Use the form of Split() that takes an output pointer and returns a "
- "status instead.")
- static std::vector<SparseTensor> Split(const SparseTensor& tensor,
- const int split_dim,
- const int num_split,
- Status* status = nullptr);
-
// Slice() will slice the input SparseTensor into a SparseTensor based on
// specified start and size. Both start and size are 1-D array with each
// element of the array representing one dimension. The start is the start
@@ -578,10 +569,9 @@
}
template <typename T>
-std::vector<SparseTensor> SparseTensor::Split(const SparseTensor& input_tensor,
- const int split_dim,
- const int num_split,
- Status* status /* = nullptr */) {
+Status SparseTensor::Split(const SparseTensor& input_tensor,
+ const int split_dim, const int num_split,
+ std::vector<SparseTensor>* result) {
std::vector<Tensor> output_indices;
std::vector<Tensor> output_values;
std::vector<TensorShape> output_shapes;
@@ -601,17 +591,15 @@
const int split_dim_size = input_tensor.shape()[split_dim];
const int split_size = split_dim_size / num_split;
- if (!(num_split > 0 && num_split <= split_dim_size) && status != nullptr) {
- *status = Status(error::INVALID_ARGUMENT,
- strings::StrCat("num_split must be in the interval (0, ",
- split_dim_size, "]"));
- return {};
+ if (!(num_split > 0 && num_split <= split_dim_size)) {
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("num_split must be in the interval (0, ",
+ split_dim_size, "]"));
}
if (!(split_dim >= 0 && split_dim < num_dim)) {
- *status = Status(
+ return Status(
error::INVALID_ARGUMENT,
strings::StrCat("num_dim must be in the interval [0, ", num_dim, ")"));
- return {};
}
const int residual = split_dim_size % num_split;
@@ -649,28 +637,18 @@
}
}
- std::vector<SparseTensor> output_tensors;
- output_tensors.reserve(num_split);
+ result->clear();
+ result->reserve(num_split);
for (int i = 0; i < num_split; ++i) {
SparseTensor tensor;
Status create_status =
Create(output_indices[i], output_values[i], output_shapes[i], &tensor);
- if (!create_status.ok() && status != nullptr) {
- *status = create_status;
- return {};
+ if (!create_status.ok()) {
+ return create_status;
}
- output_tensors.push_back(std::move(tensor));
+ result->push_back(std::move(tensor));
}
- return output_tensors;
-}
-
-template <typename T>
-Status SparseTensor::Split(const SparseTensor& input_tensor,
- const int split_dim, const int num_split,
- std::vector<SparseTensor>* result) {
- Status status;
- *result = Split<T>(input_tensor, split_dim, num_split, &status);
- return status;
+ return Status::OK();
}
template <typename T>
diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc
index ad8a44a..55688e5 100644
--- a/tensorflow/core/util/strided_slice_op.cc
+++ b/tensorflow/core/util/strided_slice_op.cc
@@ -83,10 +83,17 @@
{
int full_index = 0;
- const auto& strides_flat = sparse.strides_tensor.flat<T>();
+ const T* const strides_flat = sparse.strides_tensor.vec<T>().data();
dense->begin_valid = sparse.begin_tensor != nullptr;
dense->end_valid = sparse.end_tensor != nullptr;
+ const T* const begin_flat = sparse.begin_tensor != nullptr
+ ? sparse.begin_tensor->vec<T>().data()
+ : nullptr;
+ const T* const end_flat = sparse.end_tensor != nullptr
+ ? sparse.end_tensor->vec<T>().data()
+ : nullptr;
+
for (int i = 0; i < sparse.dims; i++) {
if ((1 << i) & sparse.ellipsis_mask) {
// Expand the ellipsis into the appropriate indices
@@ -112,16 +119,14 @@
}
// Gather slicing spec into appropriate index
- if (sparse.begin_tensor != nullptr) {
- const auto& begin_flat = sparse.begin_tensor->flat<T>();
- dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat(i));
+ if (begin_flat != nullptr) {
+ dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat[i]);
}
- if (sparse.end_tensor != nullptr) {
- const auto& end_flat = sparse.end_tensor->flat<T>();
- dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat(i));
+ if (end_flat != nullptr) {
+ dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat[i]);
}
dense->strides[full_index] =
- internal::SubtleMustCopy<T>(strides_flat(i));
+ internal::SubtleMustCopy<T>(strides_flat[i]);
if (sparse.begin_mask & (1 << i)) {
dense->begin_mask |= (1 << full_index);
}
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index 2dcb57a..3709ee5 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -785,7 +785,7 @@
TF_RETURN_IF_ERROR(
ParseEntryProto(iter_->key(), iter_->value(), &entry_copy));
if (!TensorShape::IsValid(entry_copy.shape())) {
- return errors::DataLoss("Invaid tensor shape: ", key, " ",
+ return errors::DataLoss("Invalid tensor shape: ", key, " ",
ProtoShortDebugString(entry_copy.shape()));
}
@@ -895,7 +895,7 @@
BundleEntryProto entry;
TF_RETURN_IF_ERROR(ParseEntryProto(iter_->key(), iter_->value(), &entry));
if (!TensorShape::IsValid(entry.shape())) {
- return errors::DataLoss("Invaid tensor shape: ", iter_->key(), " ",
+ return errors::DataLoss("Invalid tensor shape: ", iter_->key(), " ",
ProtoShortDebugString(entry.shape()));
}
diff --git a/tensorflow/core/util/tensor_ops_util.h b/tensorflow/core/util/tensor_ops_util.h
new file mode 100644
index 0000000..615f088
--- /dev/null
+++ b/tensorflow/core/util/tensor_ops_util.h
@@ -0,0 +1,128 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_
+#define TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device>
+Status ZerosLikeTensor(OpKernelContext* ctx, const Tensor& x, Tensor* out) {
+ AllocatorAttributes attr;
+ if (x.dtype() == DT_VARIANT) {
+ attr.set_on_host(true);
+ }
+ TF_RETURN_IF_ERROR(ctx->allocate_temp(x.dtype(), x.shape(), out, attr));
+
+ switch (out->dtype()) {
+#define DTYPE_CASE(dtype) \
+ case DataTypeToEnum<dtype>::value: \
+ /* TODO(skyewm): use SetZeroFunctor like in ZerosLikeOp? */ \
+ out->flat<dtype>().device(ctx->eigen_device<Device>()) = \
+ out->flat<dtype>().constant(dtype(0)); \
+ break;
+
+ TF_CALL_POD_TYPES(DTYPE_CASE)
+#undef DTYPE_CASE
+
+ case DT_INVALID: {
+ *out = Tensor(DT_INVALID);
+ break;
+ }
+ case DataTypeToEnum<Variant>::value: {
+ Variant* out_variant = out->scalar<Variant>().data();
+ TF_RETURN_IF_ERROR(
+ UnaryOpVariant<Device>(ctx, ZEROS_LIKE_VARIANT_UNARY_OP,
+ x.scalar<Variant>()(), out_variant));
+ break;
+ }
+ default:
+ return errors::InvalidArgument(
+ "Trying to compute zeros_like for unsupported dtype ",
+ DataTypeString(out->dtype()));
+ }
+ return Status::OK();
+}
+
+template <typename Device>
+Status BinaryAddTensors(OpKernelContext* ctx, const Tensor& a, const Tensor& b,
+ Tensor* out) {
+ if (a.dtype() == DT_INVALID) {
+ *out = b;
+ return Status::OK();
+ }
+ if (b.dtype() == DT_INVALID) {
+ *out = a;
+ return Status::OK();
+ }
+ if (a.dtype() != b.dtype()) {
+ return errors::InvalidArgument(
+ "Trying to add two tensors with incompatible element types. ",
+ "One is ", DataTypeString(a.dtype()), " and the other is ",
+ DataTypeString(b.dtype()));
+ }
+ if (a.shape() != b.shape()) {
+ // TODO(apassos) support broadcasting additions here?
+ return errors::InvalidArgument(
+ "Trying to add two tensors with incompatible element shapes. ",
+ "One is ", a.shape().DebugString(), " and the other is ",
+ b.shape().DebugString());
+ }
+
+ AllocatorAttributes attr;
+ if (a.dtype() == DT_VARIANT) {
+ attr.set_on_host(true);
+ }
+ TF_RETURN_IF_ERROR(ctx->allocate_temp(a.dtype(), a.shape(), out, attr));
+
+ switch (out->dtype()) {
+#define DTYPE_CASE(dtype) \
+ case DataTypeToEnum<dtype>::value: \
+ out->flat<dtype>().device(ctx->eigen_device<Device>()) = \
+ a.flat<dtype>() + b.flat<dtype>(); \
+ break;
+
+ TF_CALL_NUMBER_TYPES(DTYPE_CASE)
+#undef DTYPE_CASE
+
+ case DataTypeToEnum<Variant>::value: {
+ Variant* out_variant = out->scalar<Variant>().data();
+ TF_RETURN_IF_ERROR(BinaryOpVariants<Device>(
+ ctx, ADD_VARIANT_BINARY_OP, a.scalar<Variant>()(),
+ b.scalar<Variant>()(), out_variant));
+ break;
+ }
+ default:
+ return errors::InvalidArgument("Trying to add unsupported dtype ",
+ out->dtype());
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_TENSOR_OPS_UTIL_H_
diff --git a/tensorflow/examples/autograph/integration_tests/BUILD b/tensorflow/examples/autograph/integration_tests/BUILD
index d20c17b..2a4a0f7 100644
--- a/tensorflow/examples/autograph/integration_tests/BUILD
+++ b/tensorflow/examples/autograph/integration_tests/BUILD
@@ -22,7 +22,6 @@
"keras_test.py",
],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
"//tensorflow:tensorflow_py",
],
@@ -34,7 +33,6 @@
"list_literals_test.py",
],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
"//tensorflow:tensorflow_py",
],
diff --git a/tensorflow/examples/autograph/integration_tests/keras_test.py b/tensorflow/examples/autograph/integration_tests/keras_test.py
index 9828ac3..fc0b073 100644
--- a/tensorflow/examples/autograph/integration_tests/keras_test.py
+++ b/tensorflow/examples/autograph/integration_tests/keras_test.py
@@ -93,7 +93,7 @@
init = tf.global_variables_initializer()
with tf.Session() as sess:
- sess.run(init)
+ self.evaluate(init)
sample_input = tf.random_uniform((1, 10, 10, 1))
output = model(sample_input) # pylint: disable=not-callable
self.assertEqual(self.evaluate(output).shape, (1, 3))
diff --git a/tensorflow/examples/learn/iris_custom_decay_dnn.py b/tensorflow/examples/learn/iris_custom_decay_dnn.py
index 4a21969..73bf20f 100644
--- a/tensorflow/examples/learn/iris_custom_decay_dnn.py
+++ b/tensorflow/examples/learn/iris_custom_decay_dnn.py
@@ -76,12 +76,12 @@
classifier = tf.estimator.Estimator(model_fn=my_model)
# Train.
- train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
classifier.train(input_fn=train_input_fn, steps=1000)
# Predict.
- test_input_fn = tf.estimator.inputs.numpy_input_fn(
+ test_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
predictions = classifier.predict(input_fn=test_input_fn)
y_predicted = np.array(list(p['class'] for p in predictions))
diff --git a/tensorflow/examples/learn/iris_custom_model.py b/tensorflow/examples/learn/iris_custom_model.py
index c6bdb86..bf34d72 100644
--- a/tensorflow/examples/learn/iris_custom_model.py
+++ b/tensorflow/examples/learn/iris_custom_model.py
@@ -73,12 +73,12 @@
classifier = tf.estimator.Estimator(model_fn=my_model)
# Train.
- train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
classifier.train(input_fn=train_input_fn, steps=1000)
# Predict.
- test_input_fn = tf.estimator.inputs.numpy_input_fn(
+ test_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
predictions = classifier.predict(input_fn=test_input_fn)
y_predicted = np.array(list(p['class'] for p in predictions))
diff --git a/tensorflow/examples/tutorials/layers/cnn_mnist.py b/tensorflow/examples/tutorials/layers/cnn_mnist.py
index 1e8d7d0..670e929 100644
--- a/tensorflow/examples/tutorials/layers/cnn_mnist.py
+++ b/tensorflow/examples/tutorials/layers/cnn_mnist.py
@@ -134,7 +134,7 @@
tensors=tensors_to_log, every_n_iter=50)
# Train the model
- train_input_fn = tf.estimator.inputs.numpy_input_fn(
+ train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
x={"x": train_data},
y=train_labels,
batch_size=100,
@@ -146,11 +146,8 @@
hooks=[logging_hook])
# Evaluate the model and print results
- eval_input_fn = tf.estimator.inputs.numpy_input_fn(
- x={"x": eval_data},
- y=eval_labels,
- num_epochs=1,
- shuffle=False)
+ eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
+ x={"x": eval_data}, y=eval_labels, num_epochs=1, shuffle=False)
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print(eval_results)
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index 67e42aa..6ff41ca 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -112,9 +112,17 @@
C.TF_ImportGraphDefOptionsSetPrefix(opts, cprefix)
if len(options.Device) != 0 {
- cdev := C.CString(options.Device)
- defer C.free(unsafe.Pointer(cdev))
- C.TF_ImportGraphDefOptionsSetDefaultDevice(opts, cdev)
+ // TODO(ashankar): Remove this error and uncomment below
+ // when a release of the C library which includes
+ // https://github.com/tensorflow/tensorflow/commit/e0af5ac53e5a8ad9b07cdd5738c0a8e12f938c4e
+ // has been made.
+ // See https://github.com/tensorflow/tensorflow/issues/23257
+ return fmt.Errorf("GraphImportOptions.Device is only supported with the TensorFlow C library versions after 1.12 (or built from master). See https://github.com/tensorflow/tensorflow/issues/23257")
+ /*
+ cdev := C.CString(options.Device)
+ defer C.free(unsafe.Pointer(cdev))
+ C.TF_ImportGraphDefOptionsSetDefaultDevice(opts, cdev)
+ */
}
buf := C.TF_NewBuffer()
@@ -174,6 +182,68 @@
return ops
}
+// AddGradients adds operations to compute the partial derivatives of the sum of tensors in y
+// with respect to tensors in x, i.e., d(y[0] + y[1] + ...) / d x[0], d(y[0] + y[1] + ... ) / d x[1] etc.
+//
+// prefix, if non-empty, is the name prefix used for all operations added to the graph to compute
+// these gradients.
+func (g *Graph) AddGradients(prefix string, y []Output, x []Output, dx []Output) ([]Output, error) {
+ var (
+ cprefix *C.char
+
+ cy = make([]C.TF_Output, len(y))
+ cx = make([]C.TF_Output, len(x))
+ cdx = make([]C.TF_Output, len(dx))
+ cdy = make([]C.TF_Output, len(x))
+
+ pcy *C.TF_Output
+ pcx *C.TF_Output
+ pcdx *C.TF_Output
+ pcdy *C.TF_Output
+
+ status = newStatus()
+ )
+
+ if len(y) > 0 {
+ pcy = &cy[0]
+ for i, o := range y {
+ cy[i] = o.c()
+ }
+ }
+ if len(x) > 0 {
+ pcx = &cx[0]
+ for i, o := range x {
+ cx[i] = o.c()
+ }
+ pcdy = &cdy[0]
+ }
+ if len(dx) > 0 {
+ pcdx = &cdx[0]
+ for i, o := range dx {
+ cdx[i] = o.c()
+ }
+ }
+
+ // If prefix is "", the C.TF_AddGradientsWithPrefix need cprefix to be nil but not ""
+ if len(prefix) != 0 {
+ cprefix = C.CString(prefix)
+ defer C.free(unsafe.Pointer(cprefix))
+ }
+
+ C.TF_AddGradientsWithPrefix(g.c, cprefix, pcy, C.int(len(y)), pcx, C.int(len(x)), pcdx, status.c, pcdy)
+
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ dy := make([]Output, len(x))
+ for i, co := range cdy {
+ op := &Operation{co.oper, g}
+ dy[i] = Output{op, int(co.index)}
+ }
+
+ return dy, nil
+}
+
// OpSpec is the specification of an Operation to be added to a Graph
// (using Graph.AddOperation).
type OpSpec struct {
diff --git a/tensorflow/go/graph_test.go b/tensorflow/go/graph_test.go
index b8d65c5..067c7db 100644
--- a/tensorflow/go/graph_test.go
+++ b/tensorflow/go/graph_test.go
@@ -19,6 +19,7 @@
import (
"bytes"
"fmt"
+ "strings"
"testing"
)
@@ -80,3 +81,260 @@
t.Error(err)
}
}
+
+func TestGraphAddGradients(t *testing.T) {
+ g := NewGraph()
+ x1, err := Placeholder(g, "x1", Float)
+ if err != nil {
+ t.Fatal(err)
+ }
+ x2, err := Placeholder(g, "x2", Float)
+ if err != nil {
+ t.Fatal(err)
+ }
+ op0, err := g.AddOperation(OpSpec{
+ Type: "Square",
+ Name: "y0",
+ Input: []Input{x1},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ y0 := op0.Output(0)
+ op1, err := g.AddOperation(OpSpec{
+ Type: "Square",
+ Name: "y1",
+ Input: []Input{y0},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ y1 := op1.Output(0)
+ op2, err := g.AddOperation(OpSpec{
+ Type: "AddN",
+ Input: []Input{OutputList([]Output{y0, x2})},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ y2 := op2.Output(0)
+
+ grads0, err := g.AddGradients("", []Output{y1}, []Output{x1}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(grads0) != 1 {
+ t.Fatal(len(grads0))
+ }
+ if grads0[0].DataType() != Float {
+ t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float)
+ }
+
+ grads1, err := g.AddGradients("", []Output{y2}, []Output{x1, x2}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(grads1) != 2 {
+ t.Fatal(len(grads1))
+ }
+ if grads1[0].DataType() != Float {
+ t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), Float)
+ }
+ if grads1[1].DataType() != Float {
+ t.Fatalf("Got DataType %v, wanted %v", grads1[1].DataType(), Float)
+ }
+
+ sess, err := NewSession(g, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ c1, _ := NewTensor(float32(3.0))
+ c2, _ := NewTensor(float32(2.0))
+ outputs, err := sess.Run(
+ map[Output]*Tensor{x1: c1, x2: c2},
+ []Output{grads0[0], grads1[0], grads1[1]},
+ nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(outputs) != 3 {
+ t.Fatal(len(outputs))
+ }
+ if outputs[0].Value().(float32) != 108.0 {
+ t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value())
+ }
+ if outputs[1].Value().(float32) != 6.0 {
+ t.Fatalf("Got %v, wanted float 6.0", outputs[1].Value())
+ }
+ if outputs[2].Value().(float32) != 1.0 {
+ t.Fatalf("Got %v, wanted float 1.0", outputs[2].Value())
+ }
+}
+
+func TestGraphAddGradientsSums(t *testing.T) {
+ g := NewGraph()
+ x, err := Placeholder(g, "x", Float)
+ if err != nil {
+ t.Fatal(err)
+ }
+ op0, err := g.AddOperation(OpSpec{
+ Type: "Square",
+ Name: "y0",
+ Input: []Input{x},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ y0 := op0.Output(0)
+ op1, err := g.AddOperation(OpSpec{
+ Type: "Square",
+ Name: "y1",
+ Input: []Input{y0},
+ })
+ y1 := op1.Output(0)
+
+ grad, err := g.AddGradients("", []Output{y0, y1}, []Output{x}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(grad) != 1 {
+ t.Fatal(len(grad))
+ }
+ if grad[0].DataType() != Float {
+ t.Fatalf("Got DataType %v, wanted %v", grad[0].DataType(), Float)
+ }
+
+ sess, err := NewSession(g, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ c, _ := NewTensor(float32(3.0))
+ outputs, err := sess.Run(
+ map[Output]*Tensor{x: c},
+ []Output{grad[0]},
+ nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if outputs[0].Value().(float32) != 114.0 {
+ t.Fatalf("Got %v, wanted float 114.0", outputs[0].Value())
+ }
+}
+
+func TestGraphAddGradientsWithInitialValues(t *testing.T) {
+ g := NewGraph()
+ x, err := Placeholder(g, "x", Float)
+ op0, err := g.AddOperation(OpSpec{
+ Type: "Square",
+ Name: "y0",
+ Input: []Input{x},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ y0 := op0.Output(0)
+ op1, err := g.AddOperation(OpSpec{
+ Type: "Square",
+ Name: "y1",
+ Input: []Input{y0},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ y1 := op1.Output(0)
+
+ grads0, err := g.AddGradients("", []Output{y1}, []Output{y0}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(grads0) != 1 {
+ t.Fatal(len(grads0))
+ }
+ if grads0[0].DataType() != Float {
+ t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), Float)
+ }
+
+ grads1, err := g.AddGradients("", []Output{y0}, []Output{x}, []Output{grads0[0]})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(grads1) != 1 {
+ t.Fatal(len(grads1))
+ }
+ if grads1[0].DataType() != Float {
+ t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), Float)
+ }
+
+ sess, err := NewSession(g, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ c, _ := NewTensor(float32(3.0))
+ outputs, err := sess.Run(
+ map[Output]*Tensor{x: c},
+ []Output{grads1[0]},
+ nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if outputs[0].Value().(float32) != 108.0 {
+ t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value())
+ }
+}
+
+func TestGraphValidateGradientsNames(t *testing.T) {
+ g := NewGraph()
+ x, err := Placeholder(g, "x", Float)
+ if err != nil {
+ t.Fatal(err)
+ }
+ op0, err := g.AddOperation(OpSpec{
+ Type: "Square",
+ Name: "y0",
+ Input: []Input{x},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ y0 := op0.Output(0)
+
+ grads0, err := g.AddGradients("", []Output{y0}, []Output{x}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !strings.HasPrefix(grads0[0].Op.Name(), "gradients/") {
+ t.Fatalf("Got name %v, wanted started with gradients/", grads0[0].Op.Name())
+ }
+
+ grads1, err := g.AddGradients("", []Output{y0}, []Output{x}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !strings.HasPrefix(grads1[0].Op.Name(), "gradients_1/") {
+ t.Fatalf("Got name %v, wanted started with gradients_1/", grads1[0].Op.Name())
+ }
+
+ grads2, err := g.AddGradients("more_gradients", []Output{y0}, []Output{x}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !strings.HasPrefix(grads2[0].Op.Name(), "more_gradients/") {
+ t.Fatalf("Got name %v, wanted started with more_gradients/", grads2[0].Op.Name())
+ }
+
+ grads3, err := g.AddGradients("even_more_gradients", []Output{y0}, []Output{x}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !strings.HasPrefix(grads3[0].Op.Name(), "even_more_gradients/") {
+ t.Fatalf("Got name %v, wanted started with even_more_gradients/", grads3[0].Op.Name())
+ }
+
+ _, err = g.AddGradients("even_more_gradients", []Output{y0}, []Output{x}, nil)
+ if err == nil {
+ t.Error("AddGradients should have failed if gradients name is already existing")
+ }
+}
diff --git a/tensorflow/go/op/gradients.go b/tensorflow/go/op/gradients.go
new file mode 100644
index 0000000..c595678
--- /dev/null
+++ b/tensorflow/go/op/gradients.go
@@ -0,0 +1,49 @@
+/*
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package op
+
+import (
+ "fmt"
+
+ tf "github.com/tensorflow/tensorflow/tensorflow/go"
+)
+
+// Gradients adds gradients computation ops to the graph according to scope.
+//
+// Arguments:
+// y: output of the function to derive
+// x: inputs of the function for which partial derivatives are computed
+// dx: if not null, the partial derivatives of some loss function L w.r.t. y
+//
+// return the partial derivatives
+func Gradients(scope *Scope, y []tf.Output, x []tf.Output, dx ...tf.Output) (output []tf.Output) {
+ if len(scope.controlDependencies) > 0 {
+ scope.UpdateErr("Gradients", fmt.Errorf("Gradients does not currently support control dependencies (via Scope.WithControlDependencies)."))
+ return
+ }
+ if scope.device != "" {
+ scope.UpdateErr("Gradients", fmt.Errorf("Gradients does not currently support device annotations (via Scope.WithDevice)."))
+ return
+ }
+
+ var err error
+ if output, err = scope.graph.AddGradients(scope.opName("Gradients"), y, x, dx); err != nil {
+ scope.UpdateErr("Gradients", err)
+ return
+ }
+ return output
+}
diff --git a/tensorflow/go/op/gradients_test.go b/tensorflow/go/op/gradients_test.go
new file mode 100644
index 0000000..3d1d57b
--- /dev/null
+++ b/tensorflow/go/op/gradients_test.go
@@ -0,0 +1,246 @@
+/*
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package op
+
+import (
+ "strings"
+ "testing"
+
+ tf "github.com/tensorflow/tensorflow/tensorflow/go"
+)
+
+func TestAddGradients(t *testing.T) {
+ var (
+ s = NewScope()
+ x1 = Placeholder(s.SubScope("x1"), tf.Float)
+ x2 = Placeholder(s.SubScope("x2"), tf.Float)
+ y0 = Square(s.SubScope("y0"), x1)
+ y1 = Square(s.SubScope("y1"), y0)
+ y2 = AddN(s.SubScope("y2"), []tf.Output{y0, x2})
+ )
+
+ grads0 := Gradients(s, []tf.Output{y1}, []tf.Output{x1})
+ if err := s.Err(); err != nil {
+ t.Fatal(err)
+ }
+ if len(grads0) != 1 {
+ t.Fatal(len(grads0))
+ }
+ if grads0[0].DataType() != tf.Float {
+ t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
+ }
+
+ sub := s.SubScope("sub")
+ grads1 := Gradients(sub, []tf.Output{y2}, []tf.Output{x1, x2})
+ if err := sub.Err(); err != nil {
+ t.Fatal(err)
+ }
+ if len(grads1) != 2 {
+ t.Fatal(len(grads1))
+ }
+ if grads1[0].DataType() != tf.Float {
+ t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), tf.Float)
+ }
+ if grads1[1].DataType() != tf.Float {
+ t.Fatalf("Got DataType %v, wanted %v", grads1[1].DataType(), tf.Float)
+ }
+
+ graph, err := sub.Finalize()
+ if err != nil {
+ t.Fatal(err)
+ }
+ sess, err := tf.NewSession(graph, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ c1, _ := tf.NewTensor(float32(3.0))
+ c2, _ := tf.NewTensor(float32(3.0))
+ outputs, err := sess.Run(
+ map[tf.Output]*tf.Tensor{x1: c1, x2: c2},
+ []tf.Output{grads0[0], grads1[0], grads1[1]},
+ nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(outputs) != 3 {
+ t.Fatal(len(outputs))
+ }
+ if outputs[0].Value().(float32) != 108.0 {
+ t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value())
+ }
+ if outputs[1].Value().(float32) != 6.0 {
+ t.Fatalf("Got %v, wanted float 6.0", outputs[1].Value())
+ }
+ if outputs[2].Value().(float32) != 1.0 {
+ t.Fatalf("Got %v, wanted float 1.0", outputs[2].Value())
+ }
+}
+
+func TestAddGradientsSums(t *testing.T) {
+ var (
+ s = NewScope()
+ x = Placeholder(s.SubScope("x"), tf.Float)
+ y0 = Square(s.SubScope("y0"), x)
+ y1 = Square(s.SubScope("y1"), y0)
+ )
+
+ grad := Gradients(s, []tf.Output{y0, y1}, []tf.Output{x})
+ if err := s.Err(); err != nil {
+ t.Fatal(err)
+ }
+ if len(grad) != 1 {
+ t.Fatal(len(grad))
+ }
+ if grad[0].DataType() != tf.Float {
+ t.Fatalf("Got DataType %v, wanted %v", grad[0].DataType(), tf.Float)
+ }
+
+ graph, err := s.Finalize()
+ if err != nil {
+ t.Fatal(err)
+ }
+ sess, err := tf.NewSession(graph, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ c, _ := tf.NewTensor(float32(3.0))
+ outputs, err := sess.Run(
+ map[tf.Output]*tf.Tensor{x: c},
+ []tf.Output{grad[0]},
+ nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if outputs[0].Value().(float32) != 114.0 {
+ t.Fatalf("Got %v, wanted float 114.0", outputs[0].Value())
+ }
+}
+
+func TestAddGradientsWithInitialValues(t *testing.T) {
+ var (
+ s = NewScope()
+ x = Placeholder(s.SubScope("x1"), tf.Float)
+ y0 = Square(s.SubScope("y0"), x)
+ y1 = Square(s.SubScope("y1"), y0)
+ )
+
+ grads0 := Gradients(s, []tf.Output{y1}, []tf.Output{y0})
+ if err := s.Err(); err != nil {
+ t.Fatal(err)
+ }
+ if len(grads0) != 1 {
+ t.Fatal(len(grads0))
+ }
+ if grads0[0].DataType() != tf.Float {
+ t.Fatalf("Got DataType %v, wanted %v", grads0[0].DataType(), tf.Float)
+ }
+
+ sub := s.SubScope("sub")
+ grads1 := Gradients(sub, []tf.Output{y0}, []tf.Output{x}, grads0[0])
+ if err := sub.Err(); err != nil {
+ t.Fatal(err)
+ }
+ if len(grads1) != 1 {
+ t.Fatal(len(grads1))
+ }
+ if grads1[0].DataType() != tf.Float {
+ t.Fatalf("Got DataType %v, wanted %v", grads1[0].DataType(), tf.Float)
+ }
+
+ graph, err := sub.Finalize()
+ if err != nil {
+ t.Fatal(err)
+ }
+ sess, err := tf.NewSession(graph, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ c, _ := tf.NewTensor(float32(3.0))
+ outputs, err := sess.Run(
+ map[tf.Output]*tf.Tensor{x: c},
+ []tf.Output{grads1[0]},
+ nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if outputs[0].Value().(float32) != 108.0 {
+ t.Fatalf("Got %v, wanted float 108.0", outputs[0].Value())
+ }
+}
+
+func TestValidateGradientsNames(t *testing.T) {
+ var (
+ s = NewScope()
+ x = Placeholder(s.SubScope("x"), tf.Float)
+ y0 = Square(s.SubScope("y0"), x)
+ )
+
+ grads0 := Gradients(s, []tf.Output{y0}, []tf.Output{x})
+ if err := s.Err(); err != nil {
+ t.Fatal(err)
+ }
+ if !strings.HasPrefix(grads0[0].Op.Name(), "Gradients/") {
+ t.Fatalf("Got name %v, wanted started with Gradients/", grads0[0].Op.Name())
+ }
+
+ sub := s.SubScope("sub")
+ grads1 := Gradients(sub, []tf.Output{y0}, []tf.Output{x})
+ if err := s.Err(); err != nil {
+ t.Fatal(err)
+ }
+ if !strings.HasPrefix(grads1[0].Op.Name(), "sub/Gradients/") {
+ t.Fatalf("Got name %v, wanted started with sub/Gradients/", grads1[0].Op.Name())
+ }
+
+ Gradients(sub, []tf.Output{y0}, []tf.Output{x})
+ if err := s.Err(); err == nil {
+ t.Error("Gradients should have failed if executed more than once for scope of the same namespace")
+ }
+}
+
+func TestAddGradientsWithControlDependencies(t *testing.T) {
+ var (
+ s = NewScope()
+ zero = Const(s.SubScope("zero"), int32(0))
+ x = Placeholder(s.SubScope("x"), tf.Float)
+ y0 = Square(s.SubScope("y0"), x)
+ variable = VarHandleOp(s, tf.Int32, tf.ScalarShape())
+ init = AssignVariableOp(s, variable, zero)
+ readDeps = []*tf.Operation{init}
+ )
+ s = s.WithControlDependencies(readDeps...)
+ Gradients(s, []tf.Output{y0}, []tf.Output{x})
+ if err := s.Err(); err == nil {
+ t.Error("Gradients should have failed when control dependencies are set")
+ }
+}
+
+func TestAddGradientsWithDevice(t *testing.T) {
+ var (
+ s = NewScope()
+ x = Placeholder(s.SubScope("x"), tf.Float)
+ y0 = Square(s.SubScope("y0"), x)
+ )
+ s = s.WithDevice("/device:GPU:0")
+ Gradients(s, []tf.Output{y0}, []tf.Output{x})
+ if err := s.Err(); err == nil {
+ t.Error("Gradients should have failed when device is set")
+ }
+}
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 9b59c03..02a1335 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -463,6 +463,14 @@
}
}
+// QuantizeAndDequantizeV2RoundMode sets the optional round_mode attribute to value.
+// If not specified, defaults to "HALF_TO_EVEN"
+func QuantizeAndDequantizeV2RoundMode(value string) QuantizeAndDequantizeV2Attr {
+ return func(m optionalAttr) {
+ m["round_mode"] = value
+ }
+}
+
// Quantizes then dequantizes a tensor.
//
// This op simulates the precision loss from the quantized forward pass by:
@@ -3487,30 +3495,6 @@
return scope.AddOperation(opspec)
}
-// Add the quantile summaries to each quantile stream resource.
-//
-// An op that adds a list of quantile summaries to a quantile stream resource. Each
-// summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank)
-// for a single feature.
-//
-// Arguments:
-// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource.
-// summaries: string; List of Rank 2 Tensor each containing the summaries for a single feature.
-//
-// Returns the created operation.
-func BoostedTreesQuantileStreamResourceAddSummaries(scope *Scope, quantile_stream_resource_handle tf.Output, summaries []tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesQuantileStreamResourceAddSummaries",
- Input: []tf.Input{
- quantile_stream_resource_handle, tf.OutputList(summaries),
- },
- }
- return scope.AddOperation(opspec)
-}
-
// Makes the summary of quantiles for the batch.
//
// An op that takes a list of tensors and outputs the quantile summaries for each tensor.
@@ -5661,6 +5645,77 @@
return op.Output(0)
}
+// MapUnstageAttr is an optional argument to MapUnstage.
+type MapUnstageAttr func(optionalAttr)
+
+// MapUnstageCapacity sets the optional capacity attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func MapUnstageCapacity(value int64) MapUnstageAttr {
+ return func(m optionalAttr) {
+ m["capacity"] = value
+ }
+}
+
+// MapUnstageMemoryLimit sets the optional memory_limit attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func MapUnstageMemoryLimit(value int64) MapUnstageAttr {
+ return func(m optionalAttr) {
+ m["memory_limit"] = value
+ }
+}
+
+// MapUnstageContainer sets the optional container attribute to value.
+// If not specified, defaults to ""
+func MapUnstageContainer(value string) MapUnstageAttr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// MapUnstageSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func MapUnstageSharedName(value string) MapUnstageAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// Op removes and returns the values associated with the key
+//
+// from the underlying container. If the underlying container
+// does not contain this key, the op will block until it does.
+func MapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageAttr) (values []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtypes": dtypes}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "MapUnstage",
+ Input: []tf.Input{
+ key, indices,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if values, idx, err = makeOutputList(op, idx, "values"); err != nil {
+ scope.UpdateErr("MapUnstage", err)
+ return
+ }
+ return values
+}
+
// Compute the regularized incomplete beta integral \\(I_x(a, b)\\).
//
// The regularized incomplete beta integral is defined as:
@@ -30094,6 +30149,30 @@
return op.Output(0)
}
+// Add the quantile summaries to each quantile stream resource.
+//
+// An op that adds a list of quantile summaries to a quantile stream resource. Each
+// summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank)
+// for a single feature.
+//
+// Arguments:
+// quantile_stream_resource_handle: resource handle referring to a QuantileStreamResource.
+// summaries: string; List of Rank 2 Tensor each containing the summaries for a single feature.
+//
+// Returns the created operation.
+func BoostedTreesQuantileStreamResourceAddSummaries(scope *Scope, quantile_stream_resource_handle tf.Output, summaries []tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesQuantileStreamResourceAddSummaries",
+ Input: []tf.Input{
+ quantile_stream_resource_handle, tf.OutputList(summaries),
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
// Gets the next element from a FunctionBufferingResource.
//
// Arguments:
@@ -33926,74 +34005,3 @@
}
return scope.AddOperation(opspec)
}
-
-// MapUnstageAttr is an optional argument to MapUnstage.
-type MapUnstageAttr func(optionalAttr)
-
-// MapUnstageCapacity sets the optional capacity attribute to value.
-// If not specified, defaults to 0
-//
-// REQUIRES: value >= 0
-func MapUnstageCapacity(value int64) MapUnstageAttr {
- return func(m optionalAttr) {
- m["capacity"] = value
- }
-}
-
-// MapUnstageMemoryLimit sets the optional memory_limit attribute to value.
-// If not specified, defaults to 0
-//
-// REQUIRES: value >= 0
-func MapUnstageMemoryLimit(value int64) MapUnstageAttr {
- return func(m optionalAttr) {
- m["memory_limit"] = value
- }
-}
-
-// MapUnstageContainer sets the optional container attribute to value.
-// If not specified, defaults to ""
-func MapUnstageContainer(value string) MapUnstageAttr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// MapUnstageSharedName sets the optional shared_name attribute to value.
-// If not specified, defaults to ""
-func MapUnstageSharedName(value string) MapUnstageAttr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// Op removes and returns the values associated with the key
-//
-// from the underlying container. If the underlying container
-// does not contain this key, the op will block until it does.
-func MapUnstage(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...MapUnstageAttr) (values []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtypes": dtypes}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "MapUnstage",
- Input: []tf.Input{
- key, indices,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if values, idx, err = makeOutputList(op, idx, "values"); err != nil {
- scope.UpdateErr("MapUnstage", err)
- return
- }
- return values
-}
diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl
index 8255211..5d615f0 100644
--- a/tensorflow/lite/build_def.bzl
+++ b/tensorflow/lite/build_def.bzl
@@ -269,6 +269,7 @@
"pack",
"pad",
"padv2",
+ "placeholder_with_default",
"prelu",
"pow",
"range",
@@ -297,6 +298,7 @@
"squeeze",
"strided_slice",
"strided_slice_1d_exhaustive",
+ "strided_slice_buggy",
"sub",
"tile",
"topk",
diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h
index a3843de..2300ff4 100644
--- a/tensorflow/lite/builtin_ops.h
+++ b/tensorflow/lite/builtin_ops.h
@@ -125,6 +125,7 @@
kTfLiteBuiltinResizeNearestNeighbor = 97,
kTfLiteBuiltinLeakyRelu = 98,
kTfLiteBuiltinSquaredDifference = 99,
+ kTfLiteBuiltinMirrorPad = 100,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/lite/c/builtin_op_data.h b/tensorflow/lite/c/builtin_op_data.h
index 5a2f1fa..33aaac3 100644
--- a/tensorflow/lite/c/builtin_op_data.h
+++ b/tensorflow/lite/c/builtin_op_data.h
@@ -35,11 +35,21 @@
kTfLitePaddingValid,
} TfLitePadding;
+typedef enum {
+ kTfLiteMirrorPaddingUnknown = 0,
+ kTfLiteMirrorPaddingReflect,
+ kTfLiteMirrorPaddingSymmetric,
+} TfLiteMirrorPaddingMode;
+
typedef struct {
int width;
int height;
} TfLitePaddingValues;
+typedef struct {
+ TfLiteMirrorPaddingMode mode;
+} TfLiteMirrorPaddingParams;
+
// Possible fused activation functions.
// TODO(aselle): rename to TfLiteActivation
typedef enum {
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc
index 3b592a6..aa9b372 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -629,6 +629,19 @@
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_MIRROR_PAD: {
+ TfLiteMirrorPaddingParams* params =
+ allocator->AllocatePOD<TfLiteMirrorPaddingParams>();
+ auto* mirror_pad_params = op->builtin_options_as_MirrorPadOptions();
+ if (mirror_pad_params != nullptr) {
+ params->mode =
+ mirror_pad_params->mode() == tflite::MirrorPadMode_REFLECT
+ ? TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect
+ : TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingSymmetric;
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
// Below are the ops with no builtin_data strcture.
case BuiltinOperator_BATCH_TO_SPACE_ND:
diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h
index 9783747..e85d6df 100644
--- a/tensorflow/lite/core/subgraph.h
+++ b/tensorflow/lite/core/subgraph.h
@@ -21,6 +21,7 @@
#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/memory_planner.h"
+#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/util.h"
namespace tflite {
@@ -56,7 +57,6 @@
// interpreter.
TfLiteStatus SetVariables(std::vector<int> variables);
-
// Adds a node with the given parameters and returns the index of the new
// node in `node_index` (optionally). Interpreter will take ownership of
// `builtin_data` and destroy it with `free`. Ownership of 'init_data'
@@ -166,7 +166,6 @@
return &nodes_and_registration_[node_index];
}
-
// Change the dimensionality of a given tensor. Note, this is only acceptable
// for tensor indices that are inputs.
// Returns status of failure or success.
@@ -226,7 +225,6 @@
return kTfLiteOk;
}
-
// The default capacity of `tensors_` vector.
static constexpr int kTensorsReservedCapacity = 128;
// The capacity headroom of `tensors_` vector before calling ops'
@@ -242,6 +240,10 @@
// WARNING: This is an experimental API and subject to change.
TfLiteStatus ResetVariableTensors();
+ void SetProfiler(profiling::Profiler* profiler) { profiler_ = profiler; }
+
+ profiling::Profiler* GetProfiler() { return profiler_; }
+
private:
// Prevent 'context_' from accessing functions that are only available to
// delegated kernels.
@@ -470,6 +472,9 @@
// External contexts (kTfLiteMaxExternalContexts).
TfLiteExternalContext** external_contexts_;
+
+ // Profiler for this interpreter instance.
+ profiling::Profiler* profiler_ = nullptr;
};
} // namespace tflite
diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD
index 222a043..63e8689 100644
--- a/tensorflow/lite/delegates/flex/BUILD
+++ b/tensorflow/lite/delegates/flex/BUILD
@@ -116,6 +116,7 @@
hdrs = ["delegate_data.h"],
deps = [
":buffer_map",
+ "@com_google_absl//absl/memory",
"//tensorflow/core/common_runtime/eager:context",
] + select({
"//tensorflow:android": [
diff --git a/tensorflow/lite/delegates/flex/delegate_data.cc b/tensorflow/lite/delegates/flex/delegate_data.cc
index b62479a..1483a53 100644
--- a/tensorflow/lite/delegates/flex/delegate_data.cc
+++ b/tensorflow/lite/delegates/flex/delegate_data.cc
@@ -14,20 +14,21 @@
==============================================================================*/
#include "tensorflow/lite/delegates/flex/delegate_data.h"
+#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/lib/core/status.h"
namespace tflite {
namespace flex {
tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0",
&devices));
- std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
- new tensorflow::DeviceMgr(devices));
+ std::unique_ptr<tensorflow::DeviceMgr> device_mgr =
+ absl::make_unique<tensorflow::DeviceMgr>(std::move(devices));
// Note that Rendezvous is ref-counted so it will be automatically deleted.
tensorflow::Rendezvous* rendezvous =
new tensorflow::IntraProcessRendezvous(device_mgr.get());
diff --git a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
index eeb48d1..9c00d05 100644
--- a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
@@ -111,7 +111,7 @@
# Initialize variables
init = tf.global_variables_initializer()
- sess.run(init)
+ self.evaluate(init)
for _ in range(TRAIN_STEPS):
batch_x, batch_y = self.mnist.train.next_batch(
batch_size=self.batch_size, shuffle=False)
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD b/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD
index 07fb876..799b2e5 100644
--- a/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/BUILD
@@ -10,18 +10,46 @@
"tflite_micro_cc_test",
)
+cc_library(
+ name = "model_settings",
+ srcs = [
+ "model_settings.cc",
+ ],
+ hdrs = [
+ "model_settings.h",
+ ],
+)
+
+cc_library(
+ name = "tiny_conv_model_data",
+ srcs = [
+ "tiny_conv_model_data.cc",
+ ],
+ hdrs = [
+ "tiny_conv_model_data.h",
+ ],
+)
+
+cc_library(
+ name = "features_test_data",
+ srcs = [
+ "no_features_data.cc",
+ "yes_features_data.cc",
+ ],
+ hdrs = [
+ "no_features_data.h",
+ "yes_features_data.h",
+ ],
+)
+
tflite_micro_cc_test(
name = "micro_speech_test",
srcs = [
"micro_speech_test.cc",
- "no_features_data.cc",
- "no_features_data.h",
- "tiny_conv_model_data.cc",
- "tiny_conv_model_data.h",
- "yes_features_data.cc",
- "yes_features_data.h",
],
deps = [
+ ":features_test_data",
+ ":tiny_conv_model_data",
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite/experimental/micro:micro_framework",
"//tensorflow/lite/experimental/micro/kernels:all_ops_resolver",
@@ -31,46 +59,185 @@
],
)
+cc_library(
+ name = "preprocessor_test_data",
+ srcs = [
+ "no_30ms_sample_data.cc",
+ "no_power_spectrum_data.cc",
+ "yes_30ms_sample_data.cc",
+ "yes_power_spectrum_data.cc",
+ ],
+ hdrs = [
+ "no_30ms_sample_data.h",
+ "no_power_spectrum_data.h",
+ "yes_30ms_sample_data.h",
+ "yes_power_spectrum_data.h",
+ ],
+)
+
+cc_library(
+ name = "preprocessor_reference",
+ srcs = [
+ "preprocessor.cc",
+ ],
+ hdrs = [
+ "preprocessor.h",
+ ],
+ deps = [
+ ":model_settings",
+ "//tensorflow/lite/c:c_api_internal",
+ "//tensorflow/lite/experimental/micro:micro_framework",
+ ],
+)
+
tflite_micro_cc_test(
name = "preprocessor_reference_test",
srcs = [
- "no_30ms_sample_data.cc",
- "no_30ms_sample_data.h",
- "no_power_spectrum_data.cc",
- "no_power_spectrum_data.h",
- "preprocessor.cc",
- "preprocessor.h",
"preprocessor_test.cc",
- "yes_30ms_sample_data.cc",
- "yes_30ms_sample_data.h",
- "yes_power_spectrum_data.cc",
- "yes_power_spectrum_data.h",
],
deps = [
+ ":model_settings",
+ ":preprocessor_reference",
+ ":preprocessor_test_data",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/experimental/micro:micro_framework",
"//tensorflow/lite/experimental/micro/testing:micro_test",
],
)
+cc_library(
+ name = "preprocessor_fixed",
+ srcs = [
+ "fixed_point/preprocessor.cc",
+ ],
+ hdrs = [
+ "preprocessor.h",
+ ],
+ deps = [
+ ":model_settings",
+ "//tensorflow/lite/c:c_api_internal",
+ "//tensorflow/lite/experimental/micro:micro_framework",
+ ],
+)
+
tflite_micro_cc_test(
name = "preprocessor_fixed_test",
srcs = [
- "fixed_point/preprocessor.cc",
- "no_30ms_sample_data.cc",
- "no_30ms_sample_data.h",
- "no_power_spectrum_data.cc",
- "no_power_spectrum_data.h",
- "preprocessor.h",
"preprocessor_test.cc",
- "yes_30ms_sample_data.cc",
- "yes_30ms_sample_data.h",
- "yes_power_spectrum_data.cc",
- "yes_power_spectrum_data.h",
],
deps = [
+ ":model_settings",
+ ":preprocessor_fixed",
+ ":preprocessor_test_data",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/experimental/micro:micro_framework",
"//tensorflow/lite/experimental/micro/testing:micro_test",
],
)
+
+cc_library(
+ name = "audio_provider",
+ srcs = [
+ "audio_provider.cc",
+ ],
+ hdrs = [
+ "audio_provider.h",
+ ],
+ deps = [
+ ":model_settings",
+ "//tensorflow/lite/c:c_api_internal",
+ "//tensorflow/lite/experimental/micro:micro_framework",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "audio_provider_test",
+ srcs = [
+ "audio_provider_test.cc",
+ ],
+ deps = [
+ ":audio_provider",
+ ":model_settings",
+ "//tensorflow/lite/c:c_api_internal",
+ "//tensorflow/lite/experimental/micro:micro_framework",
+ "//tensorflow/lite/experimental/micro/testing:micro_test",
+ ],
+)
+
+cc_library(
+ name = "feature_provider",
+ srcs = [
+ "feature_provider.cc",
+ ],
+ hdrs = [
+ "feature_provider.h",
+ ],
+ deps = [
+ ":audio_provider",
+ ":model_settings",
+ ":preprocessor_reference",
+ ":timer",
+ "//tensorflow/lite/c:c_api_internal",
+ "//tensorflow/lite/experimental/micro:micro_framework",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "feature_provider_test",
+ srcs = [
+ "feature_provider_test.cc",
+ ],
+ deps = [
+ ":audio_provider",
+ ":feature_provider",
+ ":model_settings",
+ ":timer",
+ "//tensorflow/lite/c:c_api_internal",
+ "//tensorflow/lite/experimental/micro:micro_framework",
+ "//tensorflow/lite/experimental/micro/testing:micro_test",
+ ],
+)
+
+cc_library(
+ name = "timer",
+ srcs = [
+ "timer.cc",
+ ],
+ hdrs = [
+ "timer.h",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "timer_test",
+ srcs = [
+ "timer_test.cc",
+ ],
+ deps = [
+ ":timer",
+ "//tensorflow/lite/c:c_api_internal",
+ "//tensorflow/lite/experimental/micro:micro_framework",
+ "//tensorflow/lite/experimental/micro/testing:micro_test",
+ ],
+)
+
+cc_binary(
+ name = "micro_speech",
+ srcs = [
+ "main.cc",
+ ],
+ deps = [
+ ":audio_provider",
+ ":feature_provider",
+ ":features_test_data",
+ ":model_settings",
+ ":preprocessor_reference",
+ ":timer",
+ ":tiny_conv_model_data",
+ "//tensorflow/lite:schema_fbs_version",
+ "//tensorflow/lite/experimental/micro:micro_framework",
+ "//tensorflow/lite/experimental/micro/kernels:all_ops_resolver",
+ "//tensorflow/lite/experimental/micro/kernels:micro_ops",
+ "//tensorflow/lite/schema:schema_fbs",
+ ],
+)
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc
new file mode 100644
index 0000000..c0365d5
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.cc
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h"
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
+
+namespace {
+int16_t g_dummy_audio_data[kMaxAudioSampleSize];
+} // namespace
+
+TfLiteStatus GetAudioSamples(tflite::ErrorReporter* error_reporter,
+ int start_ms, int duration_ms,
+ int* audio_samples_size, int16_t** audio_samples) {
+ for (int i = 0; i < kMaxAudioSampleSize; ++i) {
+ g_dummy_audio_data[i] = 0;
+ }
+ *audio_samples_size = kMaxAudioSampleSize;
+ *audio_samples = g_dummy_audio_data;
+ return kTfLiteOk;
+}
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h
new file mode 100644
index 0000000..7e2442a
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_AUDIO_PROVIDER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_AUDIO_PROVIDER_H_
+
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
+
+// This is an abstraction around an audio source like a microphone, and is
+// expected to return 16-bit PCM sample data for a given point in time. The
+// sample data itself should be used as quickly as possible by the caller, since
+// to allow memory optimizations there are no guarantees that the samples won't
+// be overwritten by new data in the future. In practice, implementations should
+// ensure that there's a reasonable time allowed for clients to access the data
+// before any reuse.
+// The reference implementation can have no platform-specific dependencies, so
+// it just returns an array filled with zeros. For real applications, you should
+// ensure there's a specialized implementation that accesses hardware APIs.
+TfLiteStatus GetAudioSamples(tflite::ErrorReporter* error_reporter,
+ int start_ms, int duration_ms,
+ int* audio_samples_size, int16_t** audio_samples);
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_AUDIO_PROVIDER_H_
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_test.cc
new file mode 100644
index 0000000..5f7c760
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider_test.cc
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
+#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
+#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(TestAudioProvider) {
+ tflite::MicroErrorReporter micro_error_reporter;
+ tflite::ErrorReporter* error_reporter = µ_error_reporter;
+
+ int audio_samples_size = 0;
+ int16_t* audio_samples = nullptr;
+ TfLiteStatus get_status =
+ GetAudioSamples(error_reporter, 0, kFeatureSliceDurationMs,
+ &audio_samples_size, &audio_samples);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, get_status);
+ TF_LITE_MICRO_EXPECT_LE(audio_samples_size, kMaxAudioSampleSize);
+ TF_LITE_MICRO_EXPECT_NE(audio_samples, nullptr);
+
+ // Make sure we can read all of the returned memory locations.
+ int total = 0;
+ for (int i = 0; i < audio_samples_size; ++i) {
+ total += audio_samples[i];
+ }
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc
new file mode 100644
index 0000000..c4c52ac
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.cc
@@ -0,0 +1,121 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h"
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/audio_provider.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/timer.h"
+
+namespace {
+// Stores the timestamp for the previous fetch of audio data, so that we can
+// avoid recalculating all the features from scratch if some earlier timeslices
+// are still present.
+int32_t g_last_time_in_ms = 0;
+// Make sure we don't try to use cached information if this is the first call
+// into the provider.
+bool g_is_first_run = true;
+} // namespace
+
+FeatureProvider::FeatureProvider(int feature_size, uint8_t* feature_data)
+ : feature_size_(feature_size), feature_data_(feature_data) {
+ // Initialize the feature data to default values.
+ for (int n = 0; n < feature_size_; ++n) {
+ feature_data_[n] = 0;
+ }
+}
+
+FeatureProvider::~FeatureProvider() {}
+
+TfLiteStatus FeatureProvider::PopulateFeatureData(
+ tflite::ErrorReporter* error_reporter, int* how_many_new_slices) {
+ if (feature_size_ != kFeatureElementCount) {
+ error_reporter->Report("Requested feature_data_ size %d doesn't match %d",
+ feature_size_, kFeatureElementCount);
+ return kTfLiteError;
+ }
+
+ const int32_t time_in_ms = TimeInMilliseconds();
+ // Quantize the time into steps as long as each window stride, so we can
+ // figure out which audio data we need to fetch.
+ const int last_step = (g_last_time_in_ms / kFeatureSliceStrideMs);
+ const int current_step = (time_in_ms / kFeatureSliceStrideMs);
+ g_last_time_in_ms = time_in_ms;
+
+ int slices_needed = current_step - last_step;
+ // If this is the first call, make sure we don't use any cached information.
+ if (g_is_first_run) {
+ g_is_first_run = false;
+ slices_needed = kFeatureSliceCount;
+ }
+ if (slices_needed > kFeatureSliceCount) {
+ slices_needed = kFeatureSliceCount;
+ }
+ *how_many_new_slices = slices_needed;
+
+ const int slices_to_keep = kFeatureSliceCount - slices_needed;
+ const int slices_to_drop = kFeatureSliceCount - slices_to_keep;
+ // If we can avoid recalculating some slices, just move the existing data
+ // up in the spectrogram, to perform something like this:
+ // last time = 80ms current time = 120ms
+ // +-----------+ +-----------+
+ // | data@20ms | --> | data@60ms |
+ // +-----------+ -- +-----------+
+ // | data@40ms | -- --> | data@80ms |
+ // +-----------+ -- -- +-----------+
+ // | data@60ms | -- -- | <empty> |
+ // +-----------+ -- +-----------+
+ // | data@80ms | -- | <empty> |
+ // +-----------+ +-----------+
+ if (slices_to_keep > 0) {
+ for (int dest_slice = 0; dest_slice < slices_to_keep; ++dest_slice) {
+ uint8_t* dest_slice_data =
+ feature_data_ + (dest_slice * kFeatureSliceSize);
+ const int src_slice = dest_slice + slices_to_drop;
+ const uint8_t* src_slice_data =
+ feature_data_ + (src_slice * kFeatureSliceSize);
+ for (int i = 0; i < kFeatureSliceSize; ++i) {
+ dest_slice_data[i] = src_slice_data[i];
+ }
+ }
+ }
+ // Any slices that need to be filled in with feature data have their
+ // appropriate audio data pulled, and features calculated for that slice.
+ if (slices_needed > 0) {
+ for (int new_slice = slices_to_keep; new_slice < kFeatureSliceCount;
+ ++new_slice) {
+ const int new_step = (current_step - kFeatureSliceCount + 1) + new_slice;
+ const int32_t slice_start_ms = (new_step * kFeatureSliceStrideMs);
+ int16_t* audio_samples = nullptr;
+ int audio_samples_size = 0;
+ GetAudioSamples(error_reporter, slice_start_ms, kFeatureSliceDurationMs,
+ &audio_samples_size, &audio_samples);
+ if (audio_samples_size < kMaxAudioSampleSize) {
+ error_reporter->Report("Audio data size %d too small, want %d",
+ audio_samples_size, kMaxAudioSampleSize);
+ return kTfLiteError;
+ }
+ uint8_t* new_slice_data = feature_data_ + (new_slice * kFeatureSliceSize);
+ TfLiteStatus preprocess_status =
+ Preprocess(error_reporter, audio_samples, audio_samples_size,
+ kFeatureSliceSize, new_slice_data);
+ if (preprocess_status != kTfLiteOk) {
+ return preprocess_status;
+ }
+ }
+ }
+ return kTfLiteOk;
+}
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h
new file mode 100644
index 0000000..a86c56e
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_FEATURE_PROVIDER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_FEATURE_PROVIDER_H_
+
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
+
+// Binds itself to an area of memory intended to hold the input features for an
+// audio-recognition neural network model, and fills that data area with the
+// features representing the current audio input, for example from a microphone.
+// The audio features themselves are a two-dimensional array, made up of
+// horizontal slices representing the frequencies at one point in time, stacked
+// on top of each other to form a spectrogram showing how those frequencies
+// changed over time.
+class FeatureProvider {
+ public:
+ // Create the provider, and bind it to an area of memory. This memory should
+ // remain accessible for the lifetime of the provider object, since subsequent
+ // calls will fill it with feature data. The provider does no memory
+ // management of this data.
+ FeatureProvider(int feature_size, uint8_t* feature_data);
+ ~FeatureProvider();
+
+ // Fills the feature data with information from audio inputs, and returns how
+ // many feature slices were updated.
+ TfLiteStatus PopulateFeatureData(tflite::ErrorReporter* error_reporter,
+ int* how_many_new_slices);
+
+ private:
+ int feature_size_;
+ uint8_t* feature_data_;
+};
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_FEATURE_PROVIDER_H_
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc
new file mode 100644
index 0000000..1e52aec
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider_test.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
+#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
+#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(TestFeatureProvider) {
+ tflite::MicroErrorReporter micro_error_reporter;
+ tflite::ErrorReporter* error_reporter = µ_error_reporter;
+
+ uint8_t feature_data[kFeatureElementCount];
+ FeatureProvider feature_provider(kFeatureElementCount, feature_data);
+
+ int how_many_new_slices = 0;
+ TfLiteStatus populate_status = feature_provider.PopulateFeatureData(
+ error_reporter, &how_many_new_slices);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, populate_status);
+ TF_LITE_MICRO_EXPECT_EQ(kFeatureSliceCount, how_many_new_slices);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/fixed_point/preprocessor.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/fixed_point/preprocessor.cc
index de60c98..b623d8d 100644
--- a/tensorflow/lite/experimental/micro/examples/micro_speech/fixed_point/preprocessor.cc
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/fixed_point/preprocessor.cc
@@ -31,6 +31,8 @@
#include <cmath>
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
+
namespace {
// q format notation: qx.y => 1 sign bit, x-1 integer bits, y fraction bits.
@@ -66,13 +68,6 @@
return static_cast<int32_t>(roundf(input * (1 << 30)));
}
-// These constants allow us to allocate fixed-sized arrays on the stack for our
-// working memory.
-constexpr int kInputSize = 512;
-constexpr int kAverageWindowSize = 6;
-constexpr int kOutputSize =
- ((kInputSize / 2) + (kAverageWindowSize - 1)) / kAverageWindowSize;
-
// Performs a discrete Fourier transform on the real inputs. This corresponds to
// rdft() in the FFT package at http://www.kurims.kyoto-u.ac.jp/~ooura/fft.html,
// and to kiss_fftr() in KISSFFT at https://github.com/mborgerding/kissfft.
@@ -127,14 +122,14 @@
const int16_t* input, int input_size, int output_size,
uint8_t* output) {
// Ensure our input and output data arrays are valid.
- if (input_size > kInputSize) {
+ if (input_size > kMaxAudioSampleSize) {
error_reporter->Report("Input size %d larger than %d", input_size,
- kInputSize);
+ kMaxAudioSampleSize);
return kTfLiteError;
}
- if (output_size != kOutputSize) {
+ if (output_size != kFeatureSliceSize) {
error_reporter->Report("Requested output size %d doesn't match %d",
- output_size, kOutputSize);
+ output_size, kFeatureSliceSize);
return kTfLiteError;
}
@@ -142,18 +137,17 @@
// In a real application, we'd calculate this table once in an initialization
// function and store it for repeated reuse.
// q1.15 format.
- int16_t window_function[kInputSize];
+ int16_t window_function[kMaxAudioSampleSize];
CalculatePeriodicHann(input_size, window_function);
// Apply the window function to our time series input, and pad it with zeroes
// to the next power of two.
- int32_t fixed_input[kInputSize];
- for (int i = 0; i < kInputSize; ++i) {
+ int32_t fixed_input[kMaxAudioSampleSize];
+ for (int i = 0; i < kMaxAudioSampleSize; ++i) {
if (i < input_size) {
// input is int16_t. Treat as q1.15 fixed point value in range [-1,1)
// window_function is also q1.15 fixed point number
- fixed_input[i] =
- Q1_15_FixedMultiply_Q2_30(input[i], window_function[i]);
+ fixed_input[i] = Q1_15_FixedMultiply_Q2_30(input[i], window_function[i]);
} else {
fixed_input[i] = 0;
}
@@ -161,31 +155,31 @@
// Pull the frequency data from the time series sample.
// Calculated in q10.22 format from q2.30 inputs.
- int32_t fourier_values[kInputSize];
- CalculateDiscreteFourierTransform(fixed_input, kInputSize, fourier_values);
+ int32_t fourier_values[kMaxAudioSampleSize];
+ CalculateDiscreteFourierTransform(fixed_input, kMaxAudioSampleSize,
+ fourier_values);
// We have the complex numbers giving us information about each frequency
// band, but all we want to know is how strong each frequency is, so calculate
// the squared magnitude by adding together the squares of each component.
- int32_t power_spectrum[kInputSize / 2];
- for (int i = 0; i < (kInputSize / 2); ++i) {
+ int32_t power_spectrum[kMaxAudioSampleSize / 2];
+ for (int i = 0; i < (kMaxAudioSampleSize / 2); ++i) {
const int32_t real = fourier_values[(i * 2) + 0];
const int32_t imaginary = fourier_values[(i * 2) + 1];
// q10.22 results
- power_spectrum[i] =
- Q10_22_FixedMultiply_Q10_22(real, real) +
- Q10_22_FixedMultiply_Q10_22(imaginary, imaginary);
+ power_spectrum[i] = Q10_22_FixedMultiply_Q10_22(real, real) +
+ Q10_22_FixedMultiply_Q10_22(imaginary, imaginary);
}
// Finally, reduce the size of the output by averaging together six adjacent
// frequencies into each slot, producing an array of 43 values.
// Power_spectrum numbers are q10.22. Divide by kAverageWindowSize inside
// loop to prevent overflow.
- for (int i = 0; i < kOutputSize; ++i) {
+ for (int i = 0; i < kFeatureSliceSize; ++i) {
int32_t average = 0;
for (int j = 0; j < kAverageWindowSize; ++j) {
const int index = (i * kAverageWindowSize) + j;
- if (index < (kInputSize / 2)) {
+ if (index < (kMaxAudioSampleSize / 2)) {
average += power_spectrum[index] / kAverageWindowSize;
}
}
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/main.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/main.cc
new file mode 100644
index 0000000..1890c25
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/main.cc
@@ -0,0 +1,112 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/feature_provider.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h"
+#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
+#include "tensorflow/lite/experimental/micro/micro_interpreter.h"
+#include "tensorflow/lite/schema/schema_generated.h"
+#include "tensorflow/lite/version.h"
+
+int main(int argc, char* argv[]) {
+ // Set up logging.
+ tflite::MicroErrorReporter micro_error_reporter;
+ tflite::ErrorReporter* error_reporter = µ_error_reporter;
+
+ // Map the model into a usable data structure. This doesn't involve any
+ // copying or parsing, it's a very lightweight operation.
+ const tflite::Model* model = ::tflite::GetModel(g_tiny_conv_model_data);
+ if (model->version() != TFLITE_SCHEMA_VERSION) {
+ error_reporter->Report(
+ "Model provided is schema version %d not equal "
+ "to supported version %d.\n",
+ model->version(), TFLITE_SCHEMA_VERSION);
+ return 1;
+ }
+
+ // This pulls in all the operation implementations we need.
+ tflite::ops::micro::AllOpsResolver resolver;
+
+ // Create an area of memory to use for input, output, and intermediate arrays.
+ // The size of this will depend on the model you're using, and may need to be
+ // determined by experimentation.
+ const int tensor_arena_size = 10 * 1024;
+ uint8_t tensor_arena[tensor_arena_size];
+ tflite::SimpleTensorAllocator tensor_allocator(tensor_arena,
+ tensor_arena_size);
+
+ // Build an interpreter to run the model with.
+ tflite::MicroInterpreter interpreter(model, resolver, &tensor_allocator,
+ error_reporter);
+
+ // Get information about the memory area to use for the model's input.
+ TfLiteTensor* model_input = interpreter.input(0);
+ if ((model_input->dims->size != 4) || (model_input->dims->data[0] != 1) ||
+ (model_input->dims->data[1] != kFeatureSliceCount) ||
+ (model_input->dims->data[2] != kFeatureSliceSize) ||
+ (model_input->type != kTfLiteUInt8)) {
+ error_reporter->Report("Bad input tensor parameters in model");
+ return 1;
+ }
+
+ // Prepare to access the audio spectrograms from a microphone or other source
+ // that will provide the inputs to the neural network.
+ FeatureProvider feature_provider(kFeatureElementCount,
+ model_input->data.uint8);
+
+ // Keep reading and analysing audio data in an infinite loop.
+ while (true) {
+ // Fetch the spectrogram for the current time.
+ int how_many_new_slices = 0;
+ TfLiteStatus feature_status = feature_provider.PopulateFeatureData(
+ error_reporter, &how_many_new_slices);
+ if (feature_status != kTfLiteOk) {
+ error_reporter->Report("Feature generation failed");
+ return 1;
+ }
+ // If no new audio samples have been received since last time, don't bother
+ // running the network model.
+ if (how_many_new_slices == 0) {
+ continue;
+ }
+
+ // Run the model on the spectrogram input and make sure it succeeds.
+ TfLiteStatus invoke_status = interpreter.Invoke();
+ if (invoke_status != kTfLiteOk) {
+ error_reporter->Report("Invoke failed");
+ return 1;
+ }
+
+ // The output from the model is a vector containing the scores for each
+ // kind of prediction, so figure out what the highest scoring category was.
+ TfLiteTensor* output = interpreter.output(0);
+ uint8_t top_category_score = 0;
+ int top_category_index = 0;
+ for (int category_index = 0; category_index < kCategoryCount;
+ ++category_index) {
+ const uint8_t category_score = output->data.uint8[category_index];
+ if (category_score > top_category_score) {
+ top_category_score = category_score;
+ top_category_index = category_index;
+ }
+ }
+
+ error_reporter->Report("Heard %s", kCategoryLabels[top_category_index]);
+ }
+
+ return 0;
+}
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc
new file mode 100644
index 0000000..b9b8fb3
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.cc
@@ -0,0 +1,23 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
+
+const char* kCategoryLabels[kCategoryCount] = {
+ "silence",
+ "unknown",
+ "yes",
+ "no",
+};
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h b/tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h
new file mode 100644
index 0000000..1d8f312
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MODEL_SETTINGS_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MODEL_SETTINGS_H_
+
+// Keeping these as constant expressions allow us to allocate fixed-sized arrays
+// on the stack for our working memory.
+
+// The size of the input time series data we pass to the FFT to produce the
+// frequency information. This has to be a power of two, and since we're dealing
+// with 30ms of 16KHz inputs, which means 480 samples, this is the next value.
+constexpr int kMaxAudioSampleSize = 512;
+
+// All of these values are derived from the values used during model training,
+// if you change your model you'll need to update these constants.
+constexpr int kAverageWindowSize = 6;
+constexpr int kFeatureSliceSize =
+ ((kMaxAudioSampleSize / 2) + (kAverageWindowSize - 1)) / kAverageWindowSize;
+constexpr int kFeatureSliceCount = 49;
+constexpr int kFeatureElementCount = (kFeatureSliceSize * kFeatureSliceCount);
+constexpr int kFeatureSliceStrideMs = 20;
+constexpr int kFeatureSliceDurationMs = 30;
+
+constexpr int kCategoryCount = 4;
+constexpr int kSilenceIndex = 0;
+constexpr int kUnknownIndex = 1;
+extern const char* kCategoryLabels[kCategoryCount];
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_MODEL_SETTINGS_H_
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc
index 12f9e22..f4a7f80 100644
--- a/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.cc
@@ -28,14 +28,9 @@
#include <cmath>
-namespace {
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/model_settings.h"
-// These constants allow us to allocate fixed-sized arrays on the stack for our
-// working memory.
-constexpr int kInputSize = 512;
-constexpr int kAverageWindowSize = 6;
-constexpr int kOutputSize =
- ((kInputSize / 2) + (kAverageWindowSize - 1)) / kAverageWindowSize;
+namespace {
// Performs a discrete Fourier transform on the real inputs. This corresponds to
// rdft() in the FFT package at http://www.kurims.kyoto-u.ac.jp/~ooura/fft.html,
@@ -78,27 +73,27 @@
const int16_t* input, int input_size, int output_size,
uint8_t* output) {
// Ensure our input and output data arrays are valid.
- if (input_size > kInputSize) {
+ if (input_size > kMaxAudioSampleSize) {
error_reporter->Report("Input size %d larger than %d", input_size,
- kInputSize);
+ kMaxAudioSampleSize);
return kTfLiteError;
}
- if (output_size != kOutputSize) {
+ if (output_size != kFeatureSliceSize) {
error_reporter->Report("Requested output size %d doesn't match %d",
- output_size, kOutputSize);
+ output_size, kFeatureSliceSize);
return kTfLiteError;
}
// Pre-calculate the window function we'll be applying to the input data.
// In a real application, we'd calculate this table once in an initialization
// function and store it for repeated reuse.
- float window_function[kInputSize];
+ float window_function[kMaxAudioSampleSize];
CalculatePeriodicHann(input_size, window_function);
// Apply the window function to our time series input, and pad it with zeroes
// to the next power of two.
- float float_input[kInputSize];
- for (int i = 0; i < kInputSize; ++i) {
+ float float_input[kMaxAudioSampleSize];
+ for (int i = 0; i < kMaxAudioSampleSize; ++i) {
if (i < input_size) {
float_input[i] =
(input[i] * window_function[i]) / static_cast<float>(1 << 15);
@@ -108,14 +103,15 @@
}
// Pull the frequency data from the time series sample.
- float fourier_values[kInputSize];
- CalculateDiscreteFourierTransform(float_input, kInputSize, fourier_values);
+ float fourier_values[kMaxAudioSampleSize];
+ CalculateDiscreteFourierTransform(float_input, kMaxAudioSampleSize,
+ fourier_values);
// We have the complex numbers giving us information about each frequency
// band, but all we want to know is how strong each frequency is, so calculate
// the squared magnitude by adding together the squares of each component.
- float power_spectrum[kInputSize / 2];
- for (int i = 0; i < (kInputSize / 2); ++i) {
+ float power_spectrum[kMaxAudioSampleSize / 2];
+ for (int i = 0; i < (kMaxAudioSampleSize / 2); ++i) {
const float real = fourier_values[(i * 2) + 0];
const float imaginary = fourier_values[(i * 2) + 1];
power_spectrum[i] = (real * real) + (imaginary * imaginary);
@@ -123,11 +119,11 @@
// Finally, reduce the size of the output by averaging together six adjacent
// frequencies into each slot, producing an array of 43 values.
- for (int i = 0; i < kOutputSize; ++i) {
+ for (int i = 0; i < kFeatureSliceSize; ++i) {
float total = 0.0f;
for (int j = 0; j < kAverageWindowSize; ++j) {
const int index = (i * kAverageWindowSize) + j;
- if (index < (kInputSize / 2)) {
+ if (index < (kMaxAudioSampleSize / 2)) {
total += power_spectrum[index];
}
}
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h b/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h
index dede2a8..adff790 100644
--- a/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/preprocessor.h
@@ -19,6 +19,11 @@
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
+// Converts audio sample data into a more compact form that's appropriate for
+// feeding into a neural network. There are reference implementations that use
+// both floating point and fixed point available, but because the calculations
+// involved can be time-consuming, it's recommended that you use or write
+// specialized versions for your platform.
TfLiteStatus Preprocess(tflite::ErrorReporter* error_reporter,
const int16_t* input, int input_size, int output_size,
uint8_t* output);
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/timer.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/timer.cc
new file mode 100644
index 0000000..6c96a61
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/timer.cc
@@ -0,0 +1,22 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/timer.h"
+
+int32_t TimeInMilliseconds() {
+ static int current_time = 0;
+ current_time += 100;
+ return current_time;
+}
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/timer.h b/tensorflow/lite/experimental/micro/examples/micro_speech/timer.h
new file mode 100644
index 0000000..1629528
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/timer.h
@@ -0,0 +1,31 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TIMER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TIMER_H_
+
+#include <cstdint>
+
+// Returns the time in milliseconds. There's no contract about what time zero
+// represents, the accuracy, or the granularity of the result. Subsequent calls
+// will generally not return a lower value, but even that's not guaranteed if
+// there's an overflow wraparound.
+// The reference implementation of this function just returns a constantly
+// incrementing value for each call, since it would need a non-portable platform
+// call to access time information. For real applications, you'll need to write
+// your own platform-specific implementation.
+int32_t TimeInMilliseconds();
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TIMER_H_
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/timer_test.cc b/tensorflow/lite/experimental/micro/examples/micro_speech/timer_test.cc
new file mode 100644
index 0000000..0487a12
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/timer_test.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/micro/examples/micro_speech/timer.h"
+
+#include <limits>
+
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
+#include "tensorflow/lite/experimental/micro/testing/micro_test.h"
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(TestTimer) {
+ // Make sure that the technically-undefined overflow behavior we rely on below
+ // works on this platform. It's still not guaranteed, but at least this is a
+ // sanity check. Turn off when running with ASan, as it will complain about
+ // the following undefined behavior.
+#ifndef ADDRESS_SANITIZER
+ int32_t overflow_value = std::numeric_limits<int32_t>::max();
+ overflow_value += 1;
+ TF_LITE_MICRO_EXPECT_EQ(std::numeric_limits<int32_t>::min(), overflow_value);
+#endif
+
+ const int32_t first_time = TimeInMilliseconds();
+ const int32_t second_time = TimeInMilliseconds();
+
+ // It's possible that the timer may have wrapped around from +BIG_NUM to
+ // -BIG_NUM between the first and second calls, since we're storing
+ // milliseconds in a 32-bit integer. It's not reasonable that the call itself
+ // would have taken more than 2^31 milliseconds though, so look at the
+ // difference and rely on integer overflow to ensure it's accurate.
+ const int32_t time_delta = (second_time - first_time);
+ TF_LITE_MICRO_EXPECT_LE(0, time_delta);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/lite/experimental/micro/testing/micro_test.h b/tensorflow/lite/experimental/micro/testing/micro_test.h
index 10bab05..2f20dd5 100644
--- a/tensorflow/lite/experimental/micro/testing/micro_test.h
+++ b/tensorflow/lite/experimental/micro/testing/micro_test.h
@@ -153,4 +153,22 @@
} \
} while (false)
+#define TF_LITE_MICRO_EXPECT_GE(x, y) \
+ do { \
+ if ((x) < (y)) { \
+ micro_test::reporter->Report(#x " >= " #y " failed at %s:%d", __FILE__, \
+ __LINE__); \
+ micro_test::did_test_fail = true; \
+ } \
+ } while (false)
+
+#define TF_LITE_MICRO_EXPECT_LE(x, y) \
+ do { \
+ if ((x) > (y)) { \
+ micro_test::reporter->Report(#x " <= " #y " failed at %s:%d", __FILE__, \
+ __LINE__); \
+ micro_test::did_test_fail = true; \
+ } \
+ } while (false)
+
#endif // TENSORFLOW_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_
diff --git a/tensorflow/lite/experimental/writer/BUILD b/tensorflow/lite/experimental/writer/BUILD
index 506c668..57ce636 100644
--- a/tensorflow/lite/experimental/writer/BUILD
+++ b/tensorflow/lite/experimental/writer/BUILD
@@ -1,6 +1,9 @@
-package(default_visibility = [
- "//visibility:public",
-])
+package(
+ default_visibility = [
+ "//visibility:public",
+ ],
+ features = ["-parse_headers"],
+)
licenses(["notice"]) # Apache 2.0
diff --git a/tensorflow/lite/experimental/writer/option_writer_generator.cc b/tensorflow/lite/experimental/writer/option_writer_generator.cc
index 26d4a91..b44750e 100644
--- a/tensorflow/lite/experimental/writer/option_writer_generator.cc
+++ b/tensorflow/lite/experimental/writer/option_writer_generator.cc
@@ -67,6 +67,7 @@
"TfLitePackParams",
"TfLiteOneHotParams",
"TfLiteLeakyReluParams",
+ "TfLiteMirrorPaddingParams",
nullptr};
} // namespace
@@ -153,6 +154,7 @@
op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+ op_to_option_["MIRROR_PAD"] = ""; // TODO(karimnosseir): MirrorPadOptions.
// Manually specified mappings between ops and options (none)
op_to_option_["EMBEDDING_LOOKUP"] =
""; // TODO(aselle): maybe something else.
diff --git a/tensorflow/lite/g3doc/convert/index.md b/tensorflow/lite/g3doc/convert/index.md
index bc92a1c..60fa265 100644
--- a/tensorflow/lite/g3doc/convert/index.md
+++ b/tensorflow/lite/g3doc/convert/index.md
@@ -6,14 +6,20 @@
## From model training to device deployment
After a TensorFlow model is trained, the TensorFlow Lite converter uses that
-model to generate a TensorFlow Lite [FlatBuffer](https://google.github.io/flatbuffers/)
-file (`.tflite`). The converter supports as input:
+model to generate a TensorFlow Lite
+[FlatBuffer](https://google.github.io/flatbuffers/) file (`.tflite`). The
+converter supports as input:
[SavedModels](https://www.tensorflow.org/guide/saved_model#using_savedmodel_with_estimators),
frozen graphs (models generated by
[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py)),
-and `tf.keras` models. The TensorFlow Lite `FlatBuffer` file is deployed to a
-client device (generally a mobile or embedded device), and the TensorFlow Lite
+and `tf.keras` HDF5 models. The TensorFlow Lite `FlatBuffer` file is deployed to
+a client device (generally a mobile or embedded device), and the TensorFlow Lite
interpreter uses the compressed model for on-device inference. This conversion
process is shown in the diagram below:

+
+The TensorFlow Lite Converter can be used either from [Python](python_api.md) or
+from the [command line](cmdline_examples.md). This allows you to integrate the
+conversion step into the model design workflow, ensuring the model is easy to
+convert to a mobile inference graph.
diff --git a/tensorflow/lite/g3doc/convert/python_api.md b/tensorflow/lite/g3doc/convert/python_api.md
index 4bdf0d8..b914a34 100644
--- a/tensorflow/lite/g3doc/convert/python_api.md
+++ b/tensorflow/lite/g3doc/convert/python_api.md
@@ -3,10 +3,9 @@
This page provides examples on how to use the TensorFlow Lite Converter and the
TensorFlow Lite interpreter using the Python API.
-Note: TFLite recently moved from `tf.contrib.lite` to `tf.lite`. If you are
-using tensorflow `r1.12` or earlier you will need to add `.contrib` to the
-commands below. `tf.lite` works with newer builds, like the nightly build,
-which can be installed with: `pip install tf-nightly`
+Note: These docs describe the converter in the TensorFlow nightly release,
+installed using `pip install tf-nightly`. For docs describing older versions
+reference ["Converting models from TensorFlow 1.12"](#pre_tensorflow_1.12).
[TOC]
@@ -24,11 +23,6 @@
is `tf.lite.TFLiteConverter`. The API for calling the Python intepreter
is `tf.lite.Interpreter`.
-Note: Reference "Additional Instructions" sections for converting TensorFlow
-models to TensorFlow Lite
-[in TensorFlow 1.9 to TensorFlow 1.11](#pre_tensorflow_1.11) and
-[prior to TensorFlow 1.9](#pre_tensorflow_1.9)
-
`TFLiteConverter` provides class methods based on the original format of the
model. `TFLiteConverter.from_session()` is available for GraphDefs.
`TFLiteConverter.from_saved_model()` is available for SavedModels.
@@ -250,14 +244,13 @@
[Docker](https://www.tensorflow.org/install/docker), or
[build the pip package from source](https://www.tensorflow.org/install/source).
-### Converting models in TensorFlow 1.9 to TensorFlow 1.11 <a name="pre_tensorflow_1.11"></a>
+### Converting models from TensorFlow 1.12 <a name="pre_tensorflow_1.12"></a>
-To convert TensorFlow models to TensorFlow Lite in TensorFlow 1.9 through
-TensorFlow 1.11, use `TocoConverter`. `TocoConverter` is semantically
-identically to `TFLiteConverter`.
+Reference the following table to convert TensorFlow models to TensorFlow Lite in
+and before TensorFlow 1.12. Run `help()` to get details of each API.
-### Converting models prior to TensorFlow 1.9 <a name="pre_tensorflow_1.9"></a>
-
-To convert TensorFlow models to TensorFlow Lite in TensorFlow 1.7 and TensorFlow
-1.8, use the `toco_convert` function. Run `help(tf.lite.toco_convert)`
-to get details about accepted parameters.
+TensorFlow Version | Python API
+------------------ | ---------------------------------
+1.12 | `tf.contrib.lite.TFLiteConverter`
+1.9-1.11 | `tf.contrib.lite.TocoConverter`
+1.7-1.8 | `tf.contrib.lite.toco_convert`
diff --git a/tensorflow/lite/g3doc/devguide.md b/tensorflow/lite/g3doc/devguide.md
index 270cb8c..fdd0263 100644
--- a/tensorflow/lite/g3doc/devguide.md
+++ b/tensorflow/lite/g3doc/devguide.md
@@ -35,7 +35,7 @@
memory constrained devices, such as watches and phones, and has been successfully
used in Smart Replies on Android Wear. Currently, this model is Android-specific.
-These pre-trained models are [available for download](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/models.md)
+These pre-trained models are [available for download](models.md).
### Re-train Inception-V3 or MobileNet for a custom data set
@@ -57,51 +57,59 @@
[TensorFlow tutorials](../tutorials/) for examples of building and training
models). If you have already written a model, the first step is to export this
to a `tf.GraphDef` file. This is required because some formats do not store the
-model structure outside the code, and we must communicate with other parts of the
-framework. See
-[Exporting the Inference Graph](https://github.com/tensorflow/models/blob/master/research/slim/README.md)
-to create .pb file for the custom model.
+model structure outside the code, and we must communicate with other parts of
+the framework. See
+[Exporting the Inference Graph](https://www.tensorflow.org/tutorials/keras/save_and_restore_models#save_the_entire_model)
+to create file for the custom model.
-TensorFlow Lite currently supports a subset of TensorFlow operators. Refer to the
-[TensorFlow Lite & TensorFlow Compatibility Guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/g3doc/tf_ops_compatibility.md)
+TensorFlow Lite currently supports a subset of TensorFlow operators. Refer to
+the [TensorFlow Lite & TensorFlow Compatibility Guide](tf_ops_compatibility.md)
for supported operators and their usage. This set of operators will continue to
grow in future Tensorflow Lite releases.
-
## 2. Convert the model format
-The model generated (or downloaded) in the previous step is a *standard*
-Tensorflow model and you should now have a .pb or .pbtxt `tf.GraphDef` file.
-Models generated with transfer learning (re-training) or custom models must be
-converted—but, we must first freeze the graph to convert the model to the
-Tensorflow Lite format. This process uses several model formats:
+The [TensorFlow Lite Converter](convert/index.md) accepts the following file
+formats:
-* `tf.GraphDef` (.pb) —A protobuf that represents the TensorFlow training or
- computation graph. It contains operators, tensors, and variables definitions.
-* *CheckPoint* (.ckpt) —Serialized variables from a TensorFlow graph. Since this
- does not contain a graph structure, it cannot be interpreted by itself.
-* `FrozenGraphDef` —A subclass of `GraphDef` that does not contain
- variables. A `GraphDef` can be converted to a `FrozenGraphDef` by taking a
- CheckPoint and a `GraphDef`, and converting each variable into a constant
- using the value retrieved from the CheckPoint.
-* `SavedModel` —A `GraphDef` and CheckPoint with a signature that labels
- input and output arguments to a model. A `GraphDef` and CheckPoint can be
- extracted from a `SavedModel`.
-* *TensorFlow Lite model* (.tflite) —A serialized
- [FlatBuffer](https://google.github.io/flatbuffers/) that contains TensorFlow
- Lite operators and tensors for the TensorFlow Lite interpreter, similar to a
- `FrozenGraphDef`.
+* `SavedModel` — A `GraphDef` and checkpoint with a signature that labels
+ input and output arguments to a model. See the documentation for converting
+ SavedModels using [Python](convert/python_api.md#basic_savedmodel) or using
+ the [command line](convert/cmdline_examples.md#savedmodel).
+* `tf.keras` - A HDF5 file containing a model with weights and input and
+ output arguments generated by `tf.Keras`. See the documentation for
+ converting HDF5 models using
+ [Python](convert/python_api.md#basic_keras_file) or using the
+ [command line](convert/cmdline_examples.md#keras).
+* `frozen tf.GraphDef` — A subclass of `tf.GraphDef` that does not contain
+ variables. A `GraphDef` can be converted to a `frozen GraphDef` by taking a
+ checkpoint and a `GraphDef`, and converting each variable into a constant
+ using the value retrieved from the checkpoint. Instructions on converting a
+ `tf.GraphDef` to a TensorFlow Lite model are described in the next
+ subsection.
-### Freeze Graph
+### Converting a tf.GraphDef
-To use the `GraphDef` .pb file with TensorFlow Lite, you must have checkpoints
-that contain trained weight parameters. The .pb file only contains the structure
-of the graph. The process of merging the checkpoint values with the graph
-structure is called *freezing the graph*.
+TensorFlow models may be saved as a .pb or .pbtxt `tf.GraphDef` file. In order
+to convert the `tf.GraphDef` file to TensorFlow Lite, the model must first be
+frozen. This process invovles several file formats including the `frozen
+GraphDef`:
-You should have a checkpoints folder or download them for a pre-trained model
-(for example,
-[MobileNets](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md)).
+* `tf.GraphDef` (.pb or .pbtxt) — A protobuf that represents the TensorFlow
+ training or computation graph. It contains operators, tensors, and variables
+ definitions.
+* *checkpoint* (.ckpt) — Serialized variables from a TensorFlow graph. Since
+ this does not contain a graph structure, it cannot be interpreted by itself.
+* *TensorFlow Lite model* (.tflite) — A serialized
+ [FlatBuffer](https://google.github.io/flatbuffers/) that contains TensorFlow
+ Lite operators and tensors for the TensorFlow Lite interpreter.
+
+You must have checkpoints that contain trained weights. The `tf.GraphDef` file
+only contains the structure of the graph. The process of merging the checkpoint
+values with the graph structure is called *freezing the graph*.
+
+`tf.GraphDef` and checkpoint files for MobileNet models are available
+[here](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md).
To freeze the graph, use the following command (changing the arguments):
@@ -113,69 +121,53 @@
--output_node_names=MobileNetV1/Predictions/Reshape_1
```
-The `input_binary` flag must be enabled so the protobuf is read and written in
-a binary format. Set the `input_graph` and `input_checkpoint` files.
+Set the `input_binary` flag to `True` when reading a binary protobuf, a `.pb`
+file. Set to `False` for a `.pbtxt` file.
-The `output_node_names` may not be obvious outside of the code that built the
-model. The easiest way to find them is to visualize the graph, either with
-[TensorBoard](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/#3)
-or `graphviz`.
+Set `input_graph` and `input_checkpoint` to the respective filenames. The
+`output_node_names` may not be obvious outside of the code that built the model.
+The easiest way to find them is to visualize the graph, either with
+[TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard) or
+`graphviz`.
The frozen `GraphDef` is now ready for conversion to the `FlatBuffer` format
-(.tflite) for use on Android or iOS devices. For Android, the Tensorflow
-Optimizing Converter tool supports both float and quantized models. To convert
-the frozen `GraphDef` to the .tflite format:
+(.tflite) for use on Android or iOS devices. For Android, the TensorFlow Lite
+Converter tool supports both float and quantized models. To convert the frozen
+`GraphDef` to the .tflite format use a command similar to the following:
```
-toco --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
- --input_format=TENSORFLOW_GRAPHDEF \
- --output_format=TFLITE \
+tflite_convert \
--output_file=/tmp/mobilenet_v1_1.0_224.tflite \
- --inference_type=FLOAT \
- --input_type=FLOAT \
+ --graph_def_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \
--input_arrays=input \
- --output_arrays=MobilenetV1/Predictions/Reshape_1 \
- --input_shapes=1,224,224,3
+ --output_arrays=MobilenetV1/Predictions/Reshape_1
```
-The `input_file` argument should reference the frozen `GraphDef` file
-containing the model architecture. The [frozen_graph.pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz)
-file used here is available for download. `output_file` is where the TensorFlow
-Lite model will get generated. The `input_type` and `inference_type`
-arguments should be set to `FLOAT`, unless converting a
-<a href="https://www.tensorflow.org/performance/quantization">quantized model</a>.
-Setting the `input_array`, `output_array`, and `input_shape` arguments are not as
-straightforward. The easiest way to find these values is to explore the graph
-using Tensorboard. Reuse the arguments for specifying the output nodes for
-inference in the `freeze_graph` step.
+The
+[frozen_graph.pb](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz)
+file used here is available for download. Setting the `input_array` and
+`output_array` arguments is not straightforward. The easiest way to find these
+values is to explore the graph using
+[TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard). Reuse
+the arguments for specifying the output nodes for inference in the
+`freeze_graph` step.
-It is also possible to use the Tensorflow Optimizing Converter with protobufs
-from either Python or from the command line (see the
-[toco_from_protos.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/toco/python/toco_from_protos.py)
-example). This allows you to integrate the conversion step into the model design
-workflow, ensuring the model is easily convertible to a mobile inference graph.
-For example:
+### Full converter reference
-```python
-import tensorflow as tf
+The [TensorFlow Lite Converter](convert/index.md) can be
+[Python](convert/python_api.md) or from the
+[command line](convert/cmdline_examples.md). This allows you to integrate the
+conversion step into the model design workflow, ensuring the model is easy to
+convert to a mobile inference graph.
-img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
-val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
-out = tf.identity(val, name="out")
+### Ops compatibility
-with tf.Session() as sess:
- tflite_model = tf.lite.toco_convert(sess.graph_def, [img], [out])
- open("converteds_model.tflite", "wb").write(tflite_model)
-```
-
-For usage, see the Tensorflow Optimizing Converter
-[command-line examples](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/toco/g3doc/cmdline_examples.md).
-
-Refer to the
-[Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/g3doc/tf_ops_compatibility.md)
-for troubleshooting help, and if that doesn't help, please
+Refer to the [ops compatibility guide](tf_ops_compatibility.md) for
+troubleshooting help, and if that doesn't help, please
[file an issue](https://github.com/tensorflow/tensorflow/issues).
+### Graph vizualization tool
+
The [development repo](https://github.com/tensorflow/tensorflow) contains a tool
to visualize TensorFlow Lite models after conversion. To build the
[visualize.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/visualize.py)
@@ -212,8 +204,8 @@
### iOS
To integrate a TensorFlow model in an iOS app, see the
-[TensorFlow Lite for iOS](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/g3doc/ios.md)
-guide and <a href="./demo_ios.md">iOS demo</a> guide.
+[TensorFlow Lite for iOS](ios.md) guide and <a href="./demo_ios.md">iOS demo</a>
+guide.
#### Core ML support
@@ -227,6 +219,5 @@
### Raspberry Pi
Compile Tensorflow Lite for a Raspberry Pi by following the
-[RPi build instructions](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/rpi.md)
-This compiles a static library file (`.a`) used to build your app. There are
-plans for Python bindings and a demo app.
+[RPi build instructions](rpi.md) This compiles a static library file (`.a`) used
+to build your app. There are plans for Python bindings and a demo app.
diff --git a/tensorflow/lite/g3doc/tf_ops_compatibility.md b/tensorflow/lite/g3doc/tf_ops_compatibility.md
index 5a7bc2d..2864c6a 100644
--- a/tensorflow/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/lite/g3doc/tf_ops_compatibility.md
@@ -1,4 +1,3 @@
-
# TensorFlow Lite & TensorFlow Compatibility Guide
TensorFlow Lite supports a number of TensorFlow operations used in common
@@ -155,6 +154,30 @@
}
```
+**ARG_MAX**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: A tensor of indices of maximum values.
+}
+```
+
+**ARG_MIN**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: A tensor of indices of minium values.
+}
+```
+
**AVERAGE_POOL_2D**
```
@@ -281,6 +304,18 @@
}
```
+**FILL**
+
+```
+Inputs {
+ 0: a 1D tensor
+ 1: a 0D (scalar) tensor
+}
+Outputs {
+ 0: A tensor of shape `tensor 0` filled with the value in `tensor 1`.
+}
+```
+
**FLOOR**
```
@@ -292,6 +327,30 @@
}
```
+**FLOOR_DIV**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: result of computing element-wise floor of `tensor 0` divided by `tensor 1`.
+}
+```
+
+**FLOOR_MOD**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: result of computing element-wise floor of `tensor 0` modulo `tensor 1`.
+}
+```
+
**FULLY_CONNECTED**
```
@@ -393,6 +452,20 @@
}
```
+**LEAKY_RELU**
+
+```
+Inputs {
+ 0: a tensor
+}
+Outputs {
+ 0: a tensor equivalent to max(input, input * alpha)
+}
+Options {
+ alpha
+}
+```
+
**LESS**
```
@@ -436,6 +509,18 @@
}
```
+**LOGICAL_OR**
+
+```
+Inputs {
+ 0: a list of tensors.
+ 1: a list of tensors.
+}
+Outputs {
+ 0: A tensor of logical_or output tensors.
+}
+```
+
**LOGISTIC**
```
@@ -513,6 +598,18 @@
}
```
+**PACK**
+
+```
+Inputs {
+ 0: a list of tensors.
+ 1: an integer.
+}
+Outputs {
+ 0: A tensor of stacked tensors.
+}
+```
+
**PAD**
```
@@ -554,6 +651,35 @@
}
```
+**POW**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: elementwise pow of the input tensors
+}
+```
+
+**RANGE**
+
+```
+Inputs {
+ 0: a 0D (scalar) tensor
+ 1: a 0D (scalar) tensor
+ 2: a 0D (scalar) tensor
+}
+Outputs {
+ 0: A 1D tensor of type `dtype` defined by a sequence where `tensor 0` is the
+ start, `tensor 1` is the limit, and `tensor 2` is the delta.
+}
+Options {
+ dtype
+}
+```
+
**RELU**
```
@@ -602,6 +728,22 @@
}
```
+**RESIZE_NEAREST_NEIGHBOR**
+
+```
+Inputs {
+ 0: a 4D tensor
+ 1: a 1D tensor with 2 elements
+}
+Outputs {
+ 0: A tensor of type `tensor 0` resized according to `tensor 1` heigh/width values
+ using nearest neighbors interpolation.
+}
+Options {
+ align_corners
+}
+```
+
**RSQRT**
```
@@ -796,66 +938,6 @@
}
```
-**POW**
-
-```
-Inputs {
- 0: a tensor
- 1: a tensor
-}
-Outputs {
- 0: elementwise pow of the input tensors
-}
-```
-
-**ARG_MAX**
-
-```
-Inputs {
- 0: a tensor
- 1: a tensor
-}
-Outputs {
- 0: A tensor of indices of maximum values.
-}
-```
-
-**ARG_MIN**
-
-```
-Inputs {
- 0: a tensor
- 1: a tensor
-}
-Outputs {
- 0: A tensor of indices of minium values.
-}
-```
-
-**PACK**
-
-```
-Inputs {
- 0: a list of tensors.
- 1: an integer.
-}
-Outputs {
- 0: A tensor of stacked tensors.
-}
-```
-
-**LOGICAL_OR**
-
-```
-Inputs {
- 0: a list of tensors.
- 1: a list of tensors.
-}
-Outputs {
- 0: A tensor of logical_or output tensors.
-}
-```
-
**UNPACK**
```
@@ -869,18 +951,6 @@
}
```
-**FLOOR_DIV**
-
-```
-Inputs {
- 0: a list of tensors.
- 1: a list of tensors.
-}
-Outputs {
- 0: A tensor of floor_div output tensors.
-}
-```
-
**ZEROS_LIKE**
```
diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc
index 4f4a999..326aff5 100644
--- a/tensorflow/lite/interpreter.cc
+++ b/tensorflow/lite/interpreter.cc
@@ -35,7 +35,7 @@
Interpreter::Interpreter(ErrorReporter* error_reporter)
: error_reporter_(error_reporter ? error_reporter
: DefaultErrorReporter()) {
- subgraphs_.emplace_back(error_reporter_, external_contexts_);
+ subgraphs_.emplace_back(new Subgraph(error_reporter_, external_contexts_));
context_ = primary_subgraph().context();
// Reserve some space for the tensors to avoid excessive resizing.
@@ -136,7 +136,7 @@
void Interpreter::SetNumThreads(int num_threads) {
for (auto& subgraph : subgraphs_) {
- subgraph.context()->recommended_num_threads = num_threads;
+ subgraph->context()->recommended_num_threads = num_threads;
}
for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
@@ -149,7 +149,7 @@
void Interpreter::SetAllowFp16PrecisionForFp32(bool allow) {
for (auto& subgraph : subgraphs_) {
- subgraph.context()->allow_fp32_relax_to_fp16 = allow;
+ subgraph->context()->allow_fp32_relax_to_fp16 = allow;
}
}
@@ -190,4 +190,12 @@
return kTfLiteOk;
}
+void Interpreter::SetProfiler(profiling::Profiler* profiler) {
+ for (auto& subgraph : subgraphs_) subgraph->SetProfiler(profiler);
+}
+
+profiling::Profiler* Interpreter::GetProfiler() {
+ return primary_subgraph().GetProfiler();
+}
+
} // namespace tflite
diff --git a/tensorflow/lite/interpreter.h b/tensorflow/lite/interpreter.h
index d89afff..405cf64 100644
--- a/tensorflow/lite/interpreter.h
+++ b/tensorflow/lite/interpreter.h
@@ -380,9 +380,9 @@
TfLiteBufferHandle* buffer_handle,
TfLiteDelegate** delegate);
- void SetProfiler(profiling::Profiler* profiler) { profiler_ = profiler; }
+ void SetProfiler(profiling::Profiler* profiler);
- profiling::Profiler* GetProfiler() { return profiler_; }
+ profiling::Profiler* GetProfiler();
// The default capacity of `tensors_` vector.
static constexpr int kTensorsReservedCapacity = 128;
@@ -427,19 +427,13 @@
friend class InterpreterTest;
Subgraph& primary_subgraph() {
- return subgraphs_.front(); // Safe as subgraphs_ always has 1 entry.
+ return *subgraphs_.front(); // Safe as subgraphs_ always has 1 entry.
}
const Subgraph& primary_subgraph() const {
- return subgraphs_.front(); // Safe as subgraphs_ always has 1 entry.
+ return *subgraphs_.front(); // Safe as subgraphs_ always has 1 entry.
}
- // Tensors needed by the interpreter. Use `AddTensors` to add more blank
- // tensor entries. Note, `tensors_.data()` needs to be synchronized to the
- // `context_` whenever this std::vector is reallocated. Currently this
- // only happens in `AddTensors()`.
- // std::vector<TfLiteTensor> tensors_;
-
// Set the value of an external context.
static void SetExternalContext(struct TfLiteContext* context,
TfLiteExternalContextType type,
@@ -472,14 +466,11 @@
bool allow_buffer_handle_output_ = false;
- // Profiler for this interpreter instance.
- profiling::Profiler* profiler_ = nullptr;
-
// List of active external contexts.
TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];
// Subgraphs
- std::vector<Subgraph> subgraphs_;
+ std::vector<std::unique_ptr<Subgraph>> subgraphs_;
};
} // namespace tflite
diff --git a/tensorflow/lite/kernels/conv.cc b/tensorflow/lite/kernels/conv.cc
index 0c14b9e..1fd870b 100644
--- a/tensorflow/lite/kernels/conv.cc
+++ b/tensorflow/lite/kernels/conv.cc
@@ -24,7 +24,6 @@
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/eigen_support.h"
#include "tensorflow/lite/kernels/gemm_support.h"
-#include "tensorflow/lite/kernels/internal/optimized/cblas_conv.h"
#include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
@@ -491,11 +490,10 @@
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
KernelType effective_kernel_type;
- if ((kernel_type == kMultithreadOptimized ||
- kernel_type == kCblasOptimized) &&
+ if ((kernel_type == kMultithreadOptimized) &&
(params->dilation_width_factor != 1 ||
params->dilation_height_factor != 1)) {
- // kMultithreadOptimized and kCblasOptimized do not support dilation.
+ // kMultithreadOptimized does not support dilation.
// Therefore, fallback to optimized.
effective_kernel_type = kGenericOptimized;
} else {
@@ -521,6 +519,7 @@
GetTensorData<float>(im2col));
break;
}
+ case kCblasOptimized:
case kGenericOptimized: {
optimized_ops::Conv(op_params, GetTensorShape(input),
GetTensorData<float>(input), GetTensorShape(filter),
@@ -546,15 +545,6 @@
GetTensorData<float>(im2col));
break;
}
- case kCblasOptimized: {
- cblas_ops::Conv(op_params, GetTensorShape(input),
- GetTensorData<float>(input), GetTensorShape(filter),
- GetTensorData<float>(filter), GetTensorShape(bias),
- GetTensorData<float>(bias), GetTensorShape(output),
- GetTensorData<float>(output), GetTensorShape(im2col),
- GetTensorData<float>(im2col));
- break;
- }
}
}
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index 6d9690e..7d2653f 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -234,8 +234,6 @@
cc_library(
name = "optimized",
hdrs = [
- "optimized/cblas_conv.h",
- "optimized/cblas_reference.h",
"optimized/eigen_spatial_convolutions.h",
"optimized/eigen_tensor_reduced_instantiations_oss.h",
"optimized/multithreaded_conv.h",
diff --git a/tensorflow/lite/kernels/internal/optimized/cblas_conv.h b/tensorflow/lite/kernels/internal/optimized/cblas_conv.h
deleted file mode 100644
index 5377205..0000000
--- a/tensorflow/lite/kernels/internal/optimized/cblas_conv.h
+++ /dev/null
@@ -1,109 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_
-#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_
-
-// The Conv implementation based on CBLAS interface. This is only used on iOS
-// for now, utilizing Apple's Accelerate framework.
-
-#if TFLITE_USE_APPLE_ACCELERATE_FOR_CONV
-#include <Accelerate/Accelerate.h>
-#else
-#include "tensorflow/lite/kernels/internal/optimized/cblas_reference.h"
-#endif
-
-#include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h"
-#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
-
-namespace tflite {
-namespace cblas_ops {
-
-inline void Conv(const ConvParams& params, const RuntimeShape& input_shape,
- const float* input_data, const RuntimeShape& filter_shape,
- const float* filter_data, const RuntimeShape& bias_shape,
- const float* bias_data, const RuntimeShape& output_shape,
- float* output_data, const RuntimeShape& im2col_shape,
- float* im2col_data) {
- const int stride_width = params.stride_width;
- const int stride_height = params.stride_height;
- const int pad_width = params.padding_values.width;
- const int pad_height = params.padding_values.height;
- const int dilation_width_factor = params.dilation_width_factor;
- const int dilation_height_factor = params.dilation_height_factor;
- const float output_activation_min = params.float_activation_min;
- const float output_activation_max = params.float_activation_max;
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
- gemmlowp::ScopedProfilingLabel label("Conv/cblas");
-
- const float* gemm_input_data = nullptr;
- const RuntimeShape* gemm_input_shape = nullptr;
- const int filter_width = filter_shape.Dims(2);
- const int filter_height = filter_shape.Dims(1);
- const bool need_im2col = stride_width != 1 || stride_height != 1 ||
- filter_width != 1 || filter_height != 1;
- if (need_im2col) {
- TFLITE_DCHECK(im2col_data);
- ConvParams op_params;
- op_params.padding_type = PaddingType::kSame;
- op_params.padding_values.width = pad_width;
- op_params.padding_values.height = pad_height;
- op_params.stride_width = stride_width;
- op_params.stride_height = stride_height;
- op_params.dilation_width_factor = dilation_width_factor;
- op_params.dilation_height_factor = dilation_height_factor;
- optimized_ops::Im2col(op_params, filter_height, filter_width, 0,
- input_shape, input_data, im2col_shape, im2col_data);
-
- gemm_input_data = im2col_data;
- gemm_input_shape = &im2col_shape;
- } else {
- TFLITE_DCHECK(!im2col_data);
- gemm_input_data = input_data;
- gemm_input_shape = &input_shape;
- }
-
- // The following code computes matrix multiplication c = a * transponse(b)
- // with CBLAS, where:
- // * `a` is a matrix with dimensions (m, k).
- // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
- // * `c` is a matrix with dimensions (m, n).
- // The naming of variables are aligned with CBLAS specification here.
- const float* a = gemm_input_data;
- const float* b = filter_data;
- float* c = output_data;
- const int gemm_input_dims = gemm_input_shape->DimensionsCount();
- int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
- int n = output_shape.Dims(3);
- int k = gemm_input_shape->Dims(gemm_input_dims - 1);
- // The stride of matrix a, b and c respectively.
- int stride_a = k;
- int stride_b = k;
- int stride_c = n;
-
- cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
- stride_a, b, stride_b, 0.0f, c, stride_c);
-
- optimized_ops::AddBiasAndEvalActivationFunction(
- output_activation_min, output_activation_max, bias_shape, bias_data,
- output_shape, output_data);
-}
-
-} // namespace cblas_ops
-} // namespace tflite
-
-#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_
diff --git a/tensorflow/lite/kernels/internal/optimized/cblas_reference.h b/tensorflow/lite/kernels/internal/optimized/cblas_reference.h
deleted file mode 100644
index fa07578..0000000
--- a/tensorflow/lite/kernels/internal/optimized/cblas_reference.h
+++ /dev/null
@@ -1,69 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_
-#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_
-
-#include "tensorflow/lite/kernels/internal/compatibility.h"
-
-// The reference implementation for a small subset of CBLAS interface.
-// This is only used for testing CBLAS implementation, and should never be used
-// in production code.
-
-namespace tflite {
-namespace cblas_ops {
-
-// The following code follows the original CBLAS specification, and it might
-// conflict with the TensorFlow naming convention.
-// TODO(ycling): Find another way to test CBLAS with bazel, without writing
-// a reference implementation by ourselves.
-enum CBLAS_ORDER { CblasRowMajor = 0, CblasColMajor = 1 };
-
-enum CBLAS_TRANSPOSE { CblasNoTrans = 0, CblasTrans = 1, CblasConjTrans = 2 };
-
-// A reference implementation for matrix multiplication.
-// The following code computes, c = a * transponse(b) matrix multiplication
-// with CBLAS, where:
-// * `a` is a matrix with dimensions (m, k).
-// * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
-// * `c` is a matrix with dimensions (m, n).
-// The naming of variables is aligned with CBLAS specification here.
-void cblas_sgemm(const enum CBLAS_ORDER order,
- const enum CBLAS_TRANSPOSE trans_a,
- const enum CBLAS_TRANSPOSE trans_b, const int m, const int n,
- const int k, const float alpha, const float *a,
- const int stride_a, const float *b, const int stride_b,
- const float beta, float *c, const int stride_c) {
- TFLITE_DCHECK(order == CblasRowMajor);
- TFLITE_DCHECK(trans_a == CblasNoTrans);
- TFLITE_DCHECK(trans_b == CblasTrans);
- TFLITE_DCHECK(beta == 0.0f);
- for (int row = 0; row < m; ++row) {
- for (int col = 0; col < n; ++col) {
- // If `beta` non-zero, multiple it with the original values in output.
- // Otherwise, ignore the original value in output completely.
- float value = 0.0f;
- for (int idx = 0; idx < k; ++idx) {
- value += alpha * a[stride_a * row + idx] * b[stride_b * col + idx];
- }
- c[stride_c * row + col] = value;
- }
- }
-}
-
-} // namespace cblas_ops
-} // namespace tflite
-
-#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_
diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
index 4ff8750..df335e9 100644
--- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h
@@ -25,6 +25,10 @@
#include <tuple>
#include <type_traits>
+#if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
+#include <Accelerate/Accelerate.h>
+#endif
+
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "fixedpoint/fixedpoint.h"
@@ -1868,18 +1872,45 @@
gemm_input_shape = &input_shape;
}
- const auto im2col_matrix_map =
- MapAsMatrixWithLastDimAsRows(gemm_input_data, *gemm_input_shape);
- const auto filter_matrix_map =
- MapAsMatrixWithFirstDimAsCols(filter_data, filter_shape);
- auto output_matrix_map =
- MapAsMatrixWithLastDimAsRows(output_data, output_shape);
+ // The following code computes matrix multiplication c = a * transponse(b)
+ // with CBLAS, where:
+ // * `a` is a matrix with dimensions (m, k).
+ // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n).
+ // * `c` is a matrix with dimensions (m, n).
+ // The naming of variables are aligned with CBLAS specification here.
+ const float* a = gemm_input_data;
+ const float* b = filter_data;
+ float* c = output_data;
+ const int gemm_input_dims = gemm_input_shape->DimensionsCount();
+ int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
+ int n = output_shape.Dims(3);
+ int k = gemm_input_shape->Dims(gemm_input_dims - 1);
- Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map);
+#if defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
+ // The stride of matrix a, b and c respectively.
+ int stride_a = k;
+ int stride_b = k;
+ int stride_c = n;
- AddBiasAndEvalActivationFunction(output_activation_min, output_activation_max,
- bias_shape, bias_data, output_shape,
- output_data);
+ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a,
+ stride_a, b, stride_b, 0.0f, c, stride_c);
+#else
+ // When an optimized CBLAS implementation is not available, fall back
+ // to using Eigen.
+ typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
+ Matrix;
+ typedef Eigen::Map<Matrix> MatrixRef;
+ typedef Eigen::Map<const Matrix> ConstMatrixRef;
+
+ MatrixRef matrix_c(c, m, n);
+ ConstMatrixRef matrix_a(a, m, k);
+ ConstMatrixRef matrix_b(b, n, k);
+ matrix_c.noalias() = matrix_a * matrix_b.transpose();
+#endif // defined(TF_LITE_USE_CBLAS) && defined(__APPLE__)
+
+ optimized_ops::AddBiasAndEvalActivationFunction(
+ output_activation_min, output_activation_max, bias_shape, bias_data,
+ output_shape, output_data);
}
inline void HybridConv(const ConvParams& params, float* scaling_factors_ptr,
@@ -4293,7 +4324,6 @@
using FixedPointScaledDiff =
gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
- using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index fd37865..be766ea 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -2736,7 +2736,6 @@
using FixedPointScaledDiff =
gemmlowp::FixedPoint<int32, kScaledDiffIntegerBits>;
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
- using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
@@ -3664,8 +3663,10 @@
const RuntimeShape& unextended_output_shape, T* output_data) {
gemmlowp::ScopedProfilingLabel label("Mean");
- TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ // Current implementation only supports dimension equals 4 and simultaneous
+ // reduction over width and height.
+ TFLITE_CHECK_EQ(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_CHECK_LE(unextended_output_shape.DimensionsCount(), 4);
const RuntimeShape input_shape =
RuntimeShape::ExtendedShape(4, unextended_input_shape);
const RuntimeShape output_shape =
@@ -3679,8 +3680,6 @@
const int input_height = input_shape.Dims(1);
const int input_width = input_shape.Dims(2);
- // The current implementation only supports simultaneous reduction over
- // width and height.
TFLITE_DCHECK_EQ(op_params.axis_count, 2);
TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
(op_params.axis[0] == 2 && op_params.axis[1] == 1));
diff --git a/tensorflow/lite/kernels/pooling_test.cc b/tensorflow/lite/kernels/pooling_test.cc
index 80eef02..98777f1 100644
--- a/tensorflow/lite/kernels/pooling_test.cc
+++ b/tensorflow/lite/kernels/pooling_test.cc
@@ -67,6 +67,10 @@
QuantizeAndPopulate<uint8_t>(input_, data);
}
+ void SetInput(const std::vector<float>& data) {
+ QuantizeAndPopulate<uint8_t>(input_, data);
+ }
+
std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
std::vector<float> GetDequantizedOutput() {
return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
@@ -106,6 +110,45 @@
EXPECT_THAT(m.GetOutput(), ElementsAreArray({44, 92}));
}
+// Send in a white image, expect a white pixel.
+TEST(QuantizedPoolingOpTest, AveragePoolImageSize16) {
+ int image_size = 16;
+ QuantizedPoolingOpModel m(
+ BuiltinOperator_AVERAGE_POOL_2D,
+ /*input=*/{TensorType_UINT8, {1, image_size, image_size, 1}, 0, 16},
+ /*filter_width=*/image_size,
+ /*filter_height=*/image_size,
+ /*output=*/{TensorType_UINT8, {}, 0, 16});
+
+ std::vector<float> input(image_size * image_size, 16.f);
+ m.SetInput(input);
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ::testing::ElementsAre(255));
+ EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({16})));
+}
+
+// Send in a white image, expect something other than a white pixel, due to
+// overflow.
+TEST(QuantizedPoolingOpTest, AveragePoolImageSize17) {
+ int image_size = 17;
+ QuantizedPoolingOpModel m(
+ BuiltinOperator_AVERAGE_POOL_2D,
+ /*input=*/{TensorType_UINT8, {1, image_size, image_size, 1}, 0, 16},
+ /*filter_width=*/image_size,
+ /*filter_height=*/image_size,
+ /*output=*/{TensorType_UINT8, {}, 0, 16});
+
+ std::vector<float> input(image_size * image_size, 16.f);
+ m.SetInput(input);
+ m.Invoke();
+
+ // Ordinarily we would see '255' here. However, the optimized version of
+ // AveragePool uses a uint16 accumulator which causes it to overflow for
+ // images this large.
+ EXPECT_THAT(m.GetOutput(), ::testing::ElementsAre(28));
+}
+
TEST(FloatPoolingOpTest, MaxPool) {
FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D,
/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
diff --git a/tensorflow/lite/kernels/reduce.cc b/tensorflow/lite/kernels/reduce.cc
index ed2d475..336e827 100644
--- a/tensorflow/lite/kernels/reduce.cc
+++ b/tensorflow/lite/kernels/reduce.cc
@@ -20,6 +20,8 @@
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
@@ -229,6 +231,17 @@
return ResizeTempSum(context, &op_context, temp_sum);
}
+void ResolveAxis(const int* axis_data, int axis_count,
+ tflite::MeanParams* op_params) {
+ int i = 0;
+ for (; i < axis_count; ++i) {
+ op_params->axis[i] = static_cast<int16>(axis_data[i]);
+ }
+ for (; i < 4; ++i) {
+ op_params->axis[i] = 1;
+ }
+}
+
template <KernelType kernel_type>
TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
OpContext op_context(context, node);
@@ -257,9 +270,23 @@
if (kernel_type == kReference) {
switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, float, float));
- break;
+ case kTfLiteFloat32: {
+ tflite::MeanParams op_params;
+ op_params.axis_count = num_axis;
+ ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
+ const TfLiteTensor* input = op_context.input;
+ if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
+ op_params.axis_count == 2 &&
+ ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
+ (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
+ reference_ops::Mean(op_params, GetTensorShape(input),
+ GetTensorData<float>(input),
+ GetTensorShape(op_context.output),
+ GetTensorData<float>(op_context.output));
+ } else {
+ TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, float, float));
+ }
+ } break;
case kTfLiteInt32:
TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int, int64_t));
break;
@@ -286,7 +313,8 @@
GetTensorData<int>(op_context.axis), num_axis,
op_context.params->keep_dims, GetTensorData<int>(temp_index),
GetTensorData<int>(resolved_axis),
- GetTensorData<int>(temp_sum), /*compute_sum=*/false));
+ GetTensorData<int>(temp_sum),
+ /*compute_sum=*/false));
}
break;
default:
diff --git a/tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java b/tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java
index d5b1ac0..fbd7505 100644
--- a/tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java
+++ b/tensorflow/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java
@@ -90,29 +90,26 @@
}
private MappedByteBuffer loadModelFile() throws IOException {
- AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH);
- FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
- try {
+ try (AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH);
+ FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor())) {
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
- } finally {
- inputStream.close();
}
}
private String[] loadBackoffList() throws IOException {
List<String> labelList = new ArrayList<String>();
- BufferedReader reader =
- new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH)));
- String line;
- while ((line = reader.readLine()) != null) {
- if (!line.isEmpty()) {
- labelList.add(line);
+ try (BufferedReader reader =
+ new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH)))) {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ if (!line.isEmpty()) {
+ labelList.add(line);
+ }
}
}
- reader.close();
String[] ans = new String[labelList.size()];
labelList.toArray(ans);
return ans;
diff --git a/tensorflow/lite/nnapi_delegate.cc b/tensorflow/lite/nnapi_delegate.cc
index 292dedf..58288a8 100644
--- a/tensorflow/lite/nnapi_delegate.cc
+++ b/tensorflow/lite/nnapi_delegate.cc
@@ -683,6 +683,7 @@
case tflite::BuiltinOperator_RANGE:
case tflite::BuiltinOperator_LEAKY_RELU:
case tflite::BuiltinOperator_SQUARED_DIFFERENCE:
+ case tflite::BuiltinOperator_MIRROR_PAD:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
diff --git a/tensorflow/lite/python/op_hint.py b/tensorflow/lite/python/op_hint.py
index 3afce1b..718b230 100644
--- a/tensorflow/lite/python/op_hint.py
+++ b/tensorflow/lite/python/op_hint.py
@@ -403,7 +403,7 @@
out_graphdef: A graphdef that is ready to have this input added.
Returns:
- The the output that the stub should use as an input for this operand.
+ The output that the stub should use as an input for this operand.
Raises:
RuntimeError: if the method is not implemented.
diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs
index e40a040..652871d 100644
--- a/tensorflow/lite/schema/schema.fbs
+++ b/tensorflow/lite/schema/schema.fbs
@@ -202,6 +202,7 @@
RESIZE_NEAREST_NEIGHBOR = 97,
LEAKY_RELU = 98,
SQUARED_DIFFERENCE = 99,
+ MIRROR_PAD = 100,
}
// Options for the builtin operators.
@@ -282,6 +283,7 @@
ResizeNearestNeighborOptions,
LeakyReluOptions,
SquaredDifferenceOptions,
+ MirrorPadOptions,
}
enum Padding : byte { SAME, VALID }
@@ -669,6 +671,17 @@
table SquaredDifferenceOptions {
}
+enum MirrorPadMode : byte {
+ // Doesn't include borders.
+ REFLECT = 0,
+ // Includes borders.
+ SYMMETRIC = 1,
+}
+
+table MirrorPadOptions {
+ mode:MirrorPadMode;
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h
index e93cb3d..1464c75 100755
--- a/tensorflow/lite/schema/schema_generated.h
+++ b/tensorflow/lite/schema/schema_generated.h
@@ -259,6 +259,9 @@
struct SquaredDifferenceOptions;
struct SquaredDifferenceOptionsT;
+struct MirrorPadOptions;
+struct MirrorPadOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -508,11 +511,12 @@
BuiltinOperator_RESIZE_NEAREST_NEIGHBOR = 97,
BuiltinOperator_LEAKY_RELU = 98,
BuiltinOperator_SQUARED_DIFFERENCE = 99,
+ BuiltinOperator_MIRROR_PAD = 100,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_SQUARED_DIFFERENCE
+ BuiltinOperator_MAX = BuiltinOperator_MIRROR_PAD
};
-inline const BuiltinOperator (&EnumValuesBuiltinOperator())[99] {
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[100] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -612,7 +616,8 @@
BuiltinOperator_RANGE,
BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
BuiltinOperator_LEAKY_RELU,
- BuiltinOperator_SQUARED_DIFFERENCE
+ BuiltinOperator_SQUARED_DIFFERENCE,
+ BuiltinOperator_MIRROR_PAD
};
return values;
}
@@ -719,6 +724,7 @@
"RESIZE_NEAREST_NEIGHBOR",
"LEAKY_RELU",
"SQUARED_DIFFERENCE",
+ "MIRROR_PAD",
nullptr
};
return names;
@@ -807,11 +813,12 @@
BuiltinOptions_ResizeNearestNeighborOptions = 74,
BuiltinOptions_LeakyReluOptions = 75,
BuiltinOptions_SquaredDifferenceOptions = 76,
+ BuiltinOptions_MirrorPadOptions = 77,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_SquaredDifferenceOptions
+ BuiltinOptions_MAX = BuiltinOptions_MirrorPadOptions
};
-inline const BuiltinOptions (&EnumValuesBuiltinOptions())[77] {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[78] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -889,7 +896,8 @@
BuiltinOptions_RangeOptions,
BuiltinOptions_ResizeNearestNeighborOptions,
BuiltinOptions_LeakyReluOptions,
- BuiltinOptions_SquaredDifferenceOptions
+ BuiltinOptions_SquaredDifferenceOptions,
+ BuiltinOptions_MirrorPadOptions
};
return values;
}
@@ -973,6 +981,7 @@
"ResizeNearestNeighborOptions",
"LeakyReluOptions",
"SquaredDifferenceOptions",
+ "MirrorPadOptions",
nullptr
};
return names;
@@ -1291,6 +1300,10 @@
static const BuiltinOptions enum_value = BuiltinOptions_SquaredDifferenceOptions;
};
+template<> struct BuiltinOptionsTraits<MirrorPadOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_MirrorPadOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1930,6 +1943,14 @@
return type == BuiltinOptions_SquaredDifferenceOptions ?
reinterpret_cast<const SquaredDifferenceOptionsT *>(value) : nullptr;
}
+ MirrorPadOptionsT *AsMirrorPadOptions() {
+ return type == BuiltinOptions_MirrorPadOptions ?
+ reinterpret_cast<MirrorPadOptionsT *>(value) : nullptr;
+ }
+ const MirrorPadOptionsT *AsMirrorPadOptions() const {
+ return type == BuiltinOptions_MirrorPadOptions ?
+ reinterpret_cast<const MirrorPadOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -2127,6 +2148,35 @@
return EnumNamesCombinerType()[index];
}
+enum MirrorPadMode {
+ MirrorPadMode_REFLECT = 0,
+ MirrorPadMode_SYMMETRIC = 1,
+ MirrorPadMode_MIN = MirrorPadMode_REFLECT,
+ MirrorPadMode_MAX = MirrorPadMode_SYMMETRIC
+};
+
+inline const MirrorPadMode (&EnumValuesMirrorPadMode())[2] {
+ static const MirrorPadMode values[] = {
+ MirrorPadMode_REFLECT,
+ MirrorPadMode_SYMMETRIC
+ };
+ return values;
+}
+
+inline const char * const *EnumNamesMirrorPadMode() {
+ static const char * const names[] = {
+ "REFLECT",
+ "SYMMETRIC",
+ nullptr
+ };
+ return names;
+}
+
+inline const char *EnumNameMirrorPadMode(MirrorPadMode e) {
+ const size_t index = static_cast<int>(e);
+ return EnumNamesMirrorPadMode()[index];
+}
+
enum CustomOptionsFormat {
CustomOptionsFormat_FLEXBUFFERS = 0,
CustomOptionsFormat_MIN = CustomOptionsFormat_FLEXBUFFERS,
@@ -6769,6 +6819,60 @@
flatbuffers::Offset<SquaredDifferenceOptions> CreateSquaredDifferenceOptions(flatbuffers::FlatBufferBuilder &_fbb, const SquaredDifferenceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct MirrorPadOptionsT : public flatbuffers::NativeTable {
+ typedef MirrorPadOptions TableType;
+ MirrorPadMode mode;
+ MirrorPadOptionsT()
+ : mode(MirrorPadMode_REFLECT) {
+ }
+};
+
+struct MirrorPadOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef MirrorPadOptionsT NativeTableType;
+ enum {
+ VT_MODE = 4
+ };
+ MirrorPadMode mode() const {
+ return static_cast<MirrorPadMode>(GetField<int8_t>(VT_MODE, 0));
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_MODE) &&
+ verifier.EndTable();
+ }
+ MirrorPadOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(MirrorPadOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<MirrorPadOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const MirrorPadOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct MirrorPadOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_mode(MirrorPadMode mode) {
+ fbb_.AddElement<int8_t>(MirrorPadOptions::VT_MODE, static_cast<int8_t>(mode), 0);
+ }
+ explicit MirrorPadOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ MirrorPadOptionsBuilder &operator=(const MirrorPadOptionsBuilder &);
+ flatbuffers::Offset<MirrorPadOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<MirrorPadOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<MirrorPadOptions> CreateMirrorPadOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ MirrorPadMode mode = MirrorPadMode_REFLECT) {
+ MirrorPadOptionsBuilder builder_(_fbb);
+ builder_.add_mode(mode);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<MirrorPadOptions> CreateMirrorPadOptions(flatbuffers::FlatBufferBuilder &_fbb, const MirrorPadOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -7130,6 +7234,9 @@
const SquaredDifferenceOptions *builtin_options_as_SquaredDifferenceOptions() const {
return builtin_options_type() == BuiltinOptions_SquaredDifferenceOptions ? static_cast<const SquaredDifferenceOptions *>(builtin_options()) : nullptr;
}
+ const MirrorPadOptions *builtin_options_as_MirrorPadOptions() const {
+ return builtin_options_type() == BuiltinOptions_MirrorPadOptions ? static_cast<const MirrorPadOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -7465,6 +7572,10 @@
return builtin_options_as_SquaredDifferenceOptions();
}
+template<> inline const MirrorPadOptions *Operator::builtin_options_as<MirrorPadOptions>() const {
+ return builtin_options_as_MirrorPadOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -10005,6 +10116,32 @@
_fbb);
}
+inline MirrorPadOptionsT *MirrorPadOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new MirrorPadOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void MirrorPadOptions::UnPackTo(MirrorPadOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = mode(); _o->mode = _e; };
+}
+
+inline flatbuffers::Offset<MirrorPadOptions> MirrorPadOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MirrorPadOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateMirrorPadOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<MirrorPadOptions> CreateMirrorPadOptions(flatbuffers::FlatBufferBuilder &_fbb, const MirrorPadOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MirrorPadOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _mode = _o->mode;
+ return tflite::CreateMirrorPadOptions(
+ _fbb,
+ _mode);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -10567,6 +10704,10 @@
auto ptr = reinterpret_cast<const SquaredDifferenceOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_MirrorPadOptions: {
+ auto ptr = reinterpret_cast<const MirrorPadOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -10889,6 +11030,10 @@
auto ptr = reinterpret_cast<const SquaredDifferenceOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_MirrorPadOptions: {
+ auto ptr = reinterpret_cast<const MirrorPadOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -11199,6 +11344,10 @@
auto ptr = reinterpret_cast<const SquaredDifferenceOptionsT *>(value);
return CreateSquaredDifferenceOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_MirrorPadOptions: {
+ auto ptr = reinterpret_cast<const MirrorPadOptionsT *>(value);
+ return CreateMirrorPadOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -11509,6 +11658,10 @@
value = new SquaredDifferenceOptionsT(*reinterpret_cast<SquaredDifferenceOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_MirrorPadOptions: {
+ value = new MirrorPadOptionsT(*reinterpret_cast<MirrorPadOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -11896,6 +12049,11 @@
delete ptr;
break;
}
+ case BuiltinOptions_MirrorPadOptions: {
+ auto ptr = reinterpret_cast<MirrorPadOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/lite/testing/generate_examples.py b/tensorflow/lite/testing/generate_examples.py
index 9b0f59f..b143f45 100644
--- a/tensorflow/lite/testing/generate_examples.py
+++ b/tensorflow/lite/testing/generate_examples.py
@@ -905,40 +905,46 @@
def f(zip_path):
"""Actual function that generates examples."""
- test_parameters = [{
- "input_dtype": [tf.float32, tf.int32, tf.int64],
- "input_shape": [[3, 2, 4]],
- "axis": [
- 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0],
- [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0],
- [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3]
- ],
- "const_axis": [True, False],
- "keepdims": [True, False],
- }, {
- "input_dtype": [tf.float32],
- "input_shape": [[1, 8, 8, 3]],
- "axis": [
- 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3],
- [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2,
- -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2],
- [2, 2, 3], [-3, -3, -4], [-3, 2, 1]
- ],
- "const_axis": [True, False],
- "keepdims": [True, False],
- }, {
- "input_dtype": [tf.float32],
- "input_shape": [[], [1, 8, 8, 3], [3, 2, 4]],
- "axis": [[]], # shape is: [0]
- "const_axis": [False],
- "keepdims": [True, False],
- }, {
- "input_dtype": [tf.float32],
- "input_shape": [[], [1, 8, 8, 3], [3, 2, 4]],
- "axis": [None], # shape is: []
- "const_axis": [True],
- "keepdims": [True, False],
- }]
+ test_parameters = [
+ {
+ "input_dtype": [tf.float32, tf.int32, tf.int64],
+ "input_shape": [[3, 3, 2, 4]],
+ "axis": [
+ 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0],
+ [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1],
+ [-1, 0], [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3]
+ ],
+ "const_axis": [True, False],
+ "keepdims": [True, False],
+ },
+ {
+ "input_dtype": [tf.float32],
+ "input_shape": [[1, 8, 8, 3]],
+ "axis": [
+ 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2,
+ 3], [3, 2, 1, 0],
+ [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2, -3, -4,
+ [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2],
+ [2, 2, 3], [-3, -3, -4], [-3, 2, 1]
+ ],
+ "const_axis": [True, False],
+ "keepdims": [True, False],
+ },
+ {
+ "input_dtype": [tf.float32],
+ "input_shape": [[], [1, 8, 8, 3], [3, 2, 4]],
+ "axis": [[]], # shape is: [0]
+ "const_axis": [False],
+ "keepdims": [True, False],
+ },
+ {
+ "input_dtype": [tf.float32],
+ "input_shape": [[], [1, 8, 8, 3], [3, 2, 4]],
+ "axis": [None], # shape is: []
+ "const_axis": [True],
+ "keepdims": [True, False],
+ }
+ ]
def build_graph(parameters):
"""Build the mean op testing graph."""
@@ -2520,6 +2526,32 @@
_make_strided_slice_tests(zip_path, test_parameters)
+def make_strided_slice_buggy_tests(zip_path):
+ """Make a set of tests to show strided_slice yields incorrect results."""
+
+ test_parameters = [{
+ "unused_iteration_counter": [1],
+ }]
+
+ def build_graph(parameters):
+ """Build the strided_slice op testing graph."""
+ del parameters
+ input_values = tf.placeholder(dtype=tf.float32, shape=[4, 2])
+ data = tf.constant([[0, 1, 2, 3],
+ [4, 5, 6, 7],
+ [8, 9, 10, 11],
+ [12, 13, 14, 15]], tf.float32)
+ return [input_values], [input_values + data[:, :2]]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ del parameters
+ input_values = np.zeros([4, 2], dtype=np.float32)
+ return [input_values], sess.run(
+ outputs, feed_dict={inputs[0]: input_values})
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_lstm_tests(zip_path):
"""Make a set of tests to do basic Lstm cell."""
@@ -3496,6 +3528,33 @@
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_placeholder_with_default_tests(zip_path):
+ """Make a set of tests to test placeholder_with_default."""
+
+ test_parameters = [{
+ "dtype": [tf.float32, tf.int32, tf.int64],
+ }]
+
+ def build_graph(parameters):
+ """Build the placeholder_with_default testing graph."""
+ const_node = tf.constant(
+ [1, 2, 2, 0], shape=[2, 2], dtype=parameters["dtype"])
+ input_tensor = tf.placeholder_with_default(
+ const_node, shape=[2, 2], name="input")
+ out = tf.equal(input_tensor, const_node, name="output")
+
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ numpy_type = _TF_TYPE_INFO[parameters["dtype"]][0]
+ input_value = np.array([[1, 0], [2, 1]], numpy_type)
+ return [input_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs,
+ expected_tf_success=3)
+
+
# Toco binary path provided by the generate rule.
bin_path = None
diff --git a/tensorflow/lite/testing/generated_examples_zip_test.cc b/tensorflow/lite/testing/generated_examples_zip_test.cc
index 6f31daa..91a4851 100644
--- a/tensorflow/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/lite/testing/generated_examples_zip_test.cc
@@ -102,6 +102,9 @@
{R"(^\/add.*dtype=tf\.int64)", "119126484"},
{R"(^\/floor_div.*dtype=tf\.int64)", "119126484"},
{R"(^\/squared_difference.*dtype=tf\.int64)", "119126484"},
+
+ // Strided Slice chooses the wrong dimension.
+ {R"(^\/strided_slice_buggy)", "119786029"},
};
// Allows test data to be unarchived into a temporary directory and makes
diff --git a/tensorflow/lite/testing/tflite_driver.cc b/tensorflow/lite/testing/tflite_driver.cc
index 3a0febb..27e3a37 100644
--- a/tensorflow/lite/testing/tflite_driver.cc
+++ b/tensorflow/lite/testing/tflite_driver.cc
@@ -147,9 +147,10 @@
}
TfLiteDriver::~TfLiteDriver() {
- for (TfLiteTensor* t : tensors_to_deallocate_) {
- free(t->data.raw);
+ for (auto t : tensors_to_deallocate_) {
+ DeallocateStringTensor(t.second);
}
+ interpreter_.reset();
}
void TfLiteDriver::AllocateTensors() {
@@ -242,12 +243,10 @@
case kTfLiteString: {
string s = absl::HexStringToBytes(csv_values);
- tensor->data.raw = reinterpret_cast<char*>(malloc(s.size()));
- tensor->bytes = s.size();
+ DeallocateStringTensor(tensors_to_deallocate_[id]);
+ AllocateStringTensor(id, s.size(), tensor);
memcpy(tensor->data.raw, s.data(), s.size());
- // We must remember to free the memory we allocated above.
- tensors_to_deallocate_.push_back(tensor);
break;
}
default:
diff --git a/tensorflow/lite/testing/tflite_driver.h b/tensorflow/lite/testing/tflite_driver.h
index d8b4056..1da0533 100644
--- a/tensorflow/lite/testing/tflite_driver.h
+++ b/tensorflow/lite/testing/tflite_driver.h
@@ -49,6 +49,18 @@
string ReadOutput(int id) override { return "no-op"; }
private:
+ void DeallocateStringTensor(TfLiteTensor* t) {
+ if (t) {
+ free(t->data.raw);
+ t->data.raw = nullptr;
+ }
+ }
+ void AllocateStringTensor(int id, size_t num_bytes, TfLiteTensor* t) {
+ t->data.raw = reinterpret_cast<char*>(malloc(num_bytes));
+ t->bytes = num_bytes;
+ tensors_to_deallocate_[id] = t;
+ }
+
void ResetLSTMStateTensors();
class Expectation;
@@ -59,7 +71,7 @@
std::unique_ptr<Interpreter> interpreter_;
std::map<int, std::unique_ptr<Expectation>> expected_output_;
bool must_allocate_tensors_ = true;
- std::vector<TfLiteTensor*> tensors_to_deallocate_;
+ std::map<int, TfLiteTensor*> tensors_to_deallocate_;
};
} // namespace testing
diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD
index 1430287..82aa1f5 100644
--- a/tensorflow/lite/toco/BUILD
+++ b/tensorflow/lite/toco/BUILD
@@ -395,9 +395,10 @@
# :toco is the main public command-line tool exposing the functionality
# of the :toco_tooling library.
-tf_cc_binary(
- name = "toco",
- srcs = ["toco.cc"],
+cc_library(
+ name = "toco_convert",
+ srcs = ["toco_convert.cc"],
+ hdrs = ["toco_convert.h"],
visibility = ["//visibility:public"],
deps = [
":model",
@@ -416,6 +417,51 @@
],
)
+tf_cc_binary(
+ name = "toco",
+ srcs = ["toco.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model",
+ ":model_cmdline_flags",
+ ":model_flags_proto_cc",
+ ":toco_cmdline_flags",
+ ":toco_convert",
+ ":toco_flags_proto_cc",
+ ":toco_port",
+ ":toco_tooling",
+ ":types_proto_cc",
+ "@com_google_absl//absl/strings",
+ "//tensorflow/core:lib",
+ # We cannot embed the core:ops dependency directly into :toco_tooling as
+ # it can conflict with downstream deps when toco is used as a library.
+ "//tensorflow/core:ops",
+ ],
+)
+
+tf_cc_test(
+ name = "toco_convert_test",
+ srcs = ["toco_convert_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":model",
+ ":model_cmdline_flags",
+ ":model_flags_proto_cc",
+ ":toco_cmdline_flags",
+ ":toco_convert",
+ ":toco_flags_proto_cc",
+ ":toco_port",
+ ":toco_tooling",
+ ":types_proto_cc",
+ "@com_google_googletest//:gtest_main",
+ "@com_google_absl//absl/strings",
+ "//tensorflow/core:lib",
+ # We cannot embed the core:ops dependency directly into :toco_tooling as
+ # it can conflict with downstream deps when toco is used as a library.
+ "//tensorflow/core:ops",
+ ],
+)
+
tf_cc_test(
name = "toco_port_test",
srcs = ["toco_port_test.cc"],
diff --git a/tensorflow/lite/toco/README.md b/tensorflow/lite/toco/README.md
index bd8f828..fe98a90 100644
--- a/tensorflow/lite/toco/README.md
+++ b/tensorflow/lite/toco/README.md
@@ -26,4 +26,4 @@
interpreter handles them on-device. This flow is represented in the diagram
below.
-
+
diff --git a/tensorflow/lite/toco/export_tensorflow.cc b/tensorflow/lite/toco/export_tensorflow.cc
index 1752745..bdc3a5b 100644
--- a/tensorflow/lite/toco/export_tensorflow.cc
+++ b/tensorflow/lite/toco/export_tensorflow.cc
@@ -48,7 +48,8 @@
namespace toco {
namespace {
-tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type) {
+tensorflow::DataType GetTensorFlowDataType(ArrayDataType data_type,
+ const string& error_location) {
switch (data_type) {
case ArrayDataType::kBool:
return tensorflow::DT_BOOL;
@@ -66,14 +67,21 @@
return tensorflow::DT_COMPLEX64;
default:
case ArrayDataType::kNone:
- LOG(FATAL) << "Unsupported data type: " << static_cast<int>(data_type);
+ LOG(FATAL) << "Unsupported data type '" << ArrayDataTypeName(data_type)
+ << "' in " << error_location;
return tensorflow::DT_INVALID;
}
}
+tensorflow::DataType GetTensorFlowDataTypeForOp(ArrayDataType data_type,
+ const string& op_name) {
+ return GetTensorFlowDataType(data_type, "op '" + op_name + "'");
+}
+
tensorflow::DataType GetTensorFlowDataType(const Model& model,
const string& array_name) {
- return GetTensorFlowDataType(model.GetArray(array_name).data_type);
+ return GetTensorFlowDataType(model.GetArray(array_name).data_type,
+ "array '" + array_name + "'");
}
// TensorFlow sometimes forbids what it calls "legacy scalars",
@@ -1285,7 +1293,7 @@
*range_op->add_input() = src_op.inputs[1];
*range_op->add_input() = src_op.inputs[2];
(*range_op->mutable_attr())["Tidx"].set_type(
- GetTensorFlowDataType(src_op.dtype));
+ GetTensorFlowDataTypeForOp(src_op.dtype, /*op_name=*/src_op.outputs[0]));
}
void ConvertPackOperator(const Model& model, const PackOperator& src_op,
@@ -1298,7 +1306,8 @@
}
(*pack_op->mutable_attr())["axis"].set_i(src_op.axis);
(*pack_op->mutable_attr())["N"].set_i(src_op.inputs.size());
- (*pack_op->mutable_attr())["T"].set_type(GetTensorFlowDataType(src_op.dtype));
+ (*pack_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
}
void ConvertFillOperator(const Model& model, const FillOperator& src_op,
@@ -1887,7 +1896,7 @@
GetTensorFlowDataType(model, src_op.inputs[0]);
(*new_op->mutable_attr())["T"].set_type(shape_type);
(*new_op->mutable_attr())["dtype"].set_type(
- GetTensorFlowDataType(src_op.dtype));
+ GetTensorFlowDataTypeForOp(src_op.dtype, src_op.outputs[0]));
(*new_op->mutable_attr())["seed"].set_i(src_op.seed);
(*new_op->mutable_attr())["seed2"].set_i(src_op.seed2);
}
diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc
index c51031b..b51f80c 100644
--- a/tensorflow/lite/toco/import_tensorflow.cc
+++ b/tensorflow/lite/toco/import_tensorflow.cc
@@ -219,7 +219,10 @@
// allocation code gets a bit confused. It seems that the code expects an
// empty shape for zero-sized shapes, so we will do just that, except for the
// [0] case.
- if (zero_sized_shape && input_dims_only_sizes.size() > 1) {
+ // TODO(b/119325030): In order to correctly import the "scalar" shapes the
+ // following test must include "&& input_dims_only_sizes.size() > 1", but
+ // that seems to slow everything down a lot.
+ if (zero_sized_shape) {
shape->mutable_dims()->clear();
if (input_flat_size != nullptr) *input_flat_size = 0;
return tensorflow::Status::OK();
@@ -1218,7 +1221,7 @@
void GetOutputTypesFromNodeDef(const NodeDef& node,
const tensorflow::OpDef& op_def,
TensorFlowUnsupportedOperator* op) {
- // The the given type to the op, or clear the types if invalid.
+ // The given type to the op, or clear the types if invalid.
auto add_type = [&node, op](tensorflow::DataType type) {
if (type == tensorflow::DT_INVALID) {
LOG(WARNING) << "Op node missing output type attribute: " << node.name();
@@ -2009,13 +2012,13 @@
tensorflow::SessionOptions options;
auto* device_count = options.config.mutable_device_count();
device_count->insert({"CPU", 1});
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<tensorflow::Device>> devices;
TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
options, "/job:localhost/replica:0/task:0", &devices));
tensorflow::FunctionLibraryDefinition fld(tensorflow::OpRegistry::Global(),
graphdef_copy.library());
- tensorflow::DeviceMgr device_mgr(devices);
+ tensorflow::DeviceMgr device_mgr(std::move(devices));
tensorflow::OptimizerOptions o_opts;
tensorflow::ProcessFunctionLibraryRuntime pflr(
&device_mgr, tensorflow::Env::Default(), TF_GRAPH_DEF_VERSION, &fld,
diff --git a/tensorflow/lite/toco/import_tensorflow_test.cc b/tensorflow/lite/toco/import_tensorflow_test.cc
index 07b52d3..0be358b 100644
--- a/tensorflow/lite/toco/import_tensorflow_test.cc
+++ b/tensorflow/lite/toco/import_tensorflow_test.cc
@@ -190,7 +190,9 @@
EXPECT_TRUE(ImportNode(node, &model).ok());
const auto& array = model.GetArray("Node1");
- EXPECT_THAT(array.shape().dims(), ::testing::ElementsAre(0));
+ // We would like to have [0] shapes actually import correctly, but
+ // for some reason that slows everything down.
+ EXPECT_THAT(array.shape().dims(), ::testing::ElementsAre());
}
TEST_P(ShapeImportTest, ShapeElementTooLarge) {
diff --git a/tensorflow/lite/toco/toco.cc b/tensorflow/lite/toco/toco.cc
index 9740015..4a3d6a5 100644
--- a/tensorflow/lite/toco/toco.cc
+++ b/tensorflow/lite/toco/toco.cc
@@ -16,87 +16,9 @@
#include <memory>
#include <string>
-#include "absl/strings/string_view.h"
-#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/model_cmdline_flags.h"
-#include "tensorflow/lite/toco/model_flags.pb.h"
#include "tensorflow/lite/toco/toco_cmdline_flags.h"
-#include "tensorflow/lite/toco/toco_flags.pb.h"
-#include "tensorflow/lite/toco/toco_port.h"
-#include "tensorflow/lite/toco/toco_tooling.h"
-#include "tensorflow/lite/toco/toco_types.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace toco {
-namespace {
-
-// Checks the permissions of the output file to ensure it is writeable.
-void CheckOutputFilePermissions(const Arg<string>& output_file) {
- QCHECK(output_file.specified()) << "Missing required flag --output_file.\n";
- QCHECK(port::file::Writable(output_file.value()).ok())
- << "Specified output_file is not writable: " << output_file.value()
- << ".\n";
-}
-
-// Checks the permissions of the frozen model file.
-void CheckFrozenModelPermissions(const Arg<string>& input_file) {
- QCHECK(input_file.specified()) << "Missing required flag --input_file.\n";
- QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok())
- << "Specified input_file does not exist: " << input_file.value() << ".\n";
- QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok())
- << "Specified input_file exists, but is not readable: "
- << input_file.value() << ".\n";
-}
-
-// Reads the contents of the GraphDef from either the frozen graph file or the
-// SavedModel directory. If it reads the SavedModel directory, it updates the
-// ModelFlags and TocoFlags accordingly.
-void ReadInputData(const ParsedTocoFlags& parsed_toco_flags,
- const ParsedModelFlags& parsed_model_flags,
- TocoFlags* toco_flags, ModelFlags* model_flags,
- string* graph_def_contents) {
- port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n");
-
- // Ensure savedmodel_directory is not set.
- QCHECK(!parsed_toco_flags.savedmodel_directory.specified())
- << "Use `tensorflow/lite/python/tflite_convert` script with "
- << "SavedModel directories.\n";
-
- // Checks the input file permissions and reads the contents.
- CheckFrozenModelPermissions(parsed_toco_flags.input_file);
- CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(),
- graph_def_contents, port::file::Defaults())
- .ok());
-}
-
-tensorflow::Status ToolMain(const ParsedTocoFlags& parsed_toco_flags,
- const ParsedModelFlags& parsed_model_flags) {
- ModelFlags model_flags;
- ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags);
-
- TocoFlags toco_flags;
- ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags);
-
- string graph_def_contents;
- ReadInputData(parsed_toco_flags, parsed_model_flags, &toco_flags,
- &model_flags, &graph_def_contents);
- CheckOutputFilePermissions(parsed_toco_flags.output_file);
-
- std::unique_ptr<Model> model =
- Import(toco_flags, model_flags, graph_def_contents);
- Transform(toco_flags, model.get());
- string output_file_contents;
- TF_RETURN_IF_ERROR(Export(toco_flags, *model, toco_flags.allow_custom_ops(),
- &output_file_contents));
- TF_RETURN_IF_ERROR(
- port::file::SetContents(parsed_toco_flags.output_file.value(),
- output_file_contents, port::file::Defaults()));
- return tensorflow::Status();
-}
-
-} // namespace
-} // namespace toco
+#include "tensorflow/lite/toco/toco_convert.h"
int main(int argc, char** argv) {
toco::string msg;
@@ -126,6 +48,6 @@
return 1;
}
toco::port::InitGoogle(argv[0], effective_argc, &effective_argv, true);
- auto status = toco::ToolMain(parsed_toco_flags, parsed_model_flags);
+ auto status = toco::Convert(parsed_toco_flags, parsed_model_flags);
return status.ok() ? 0 : -1;
}
diff --git a/tensorflow/lite/toco/toco_convert.cc b/tensorflow/lite/toco/toco_convert.cc
new file mode 100644
index 0000000..28e7b10
--- /dev/null
+++ b/tensorflow/lite/toco/toco_convert.cc
@@ -0,0 +1,108 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdio>
+#include <memory>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/lite/toco/model.h"
+#include "tensorflow/lite/toco/model_cmdline_flags.h"
+#include "tensorflow/lite/toco/model_flags.pb.h"
+#include "tensorflow/lite/toco/toco_cmdline_flags.h"
+#include "tensorflow/lite/toco/toco_flags.pb.h"
+#include "tensorflow/lite/toco/toco_port.h"
+#include "tensorflow/lite/toco/toco_tooling.h"
+#include "tensorflow/lite/toco/toco_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+namespace {
+
+// Checks the permissions of the output file to ensure it is writeable.
+void CheckOutputFilePermissions(const Arg<string>& output_file) {
+ QCHECK(output_file.specified()) << "Missing required flag --output_file.\n";
+ QCHECK(port::file::Writable(output_file.value()).ok())
+ << "Specified output_file is not writable: " << output_file.value()
+ << ".\n";
+}
+
+// Checks the permissions of the frozen model file.
+void CheckFrozenModelPermissions(const Arg<string>& input_file) {
+ QCHECK(input_file.specified()) << "Missing required flag --input_file.\n";
+ QCHECK(port::file::Exists(input_file.value(), port::file::Defaults()).ok())
+ << "Specified input_file does not exist: " << input_file.value() << ".\n";
+ QCHECK(port::file::Readable(input_file.value(), port::file::Defaults()).ok())
+ << "Specified input_file exists, but is not readable: "
+ << input_file.value() << ".\n";
+}
+
+// Reads the contents of the GraphDef from either the frozen graph file or the
+// SavedModel directory. If it reads the SavedModel directory, it updates the
+// ModelFlags and TocoFlags accordingly.
+void ReadInputData(const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags,
+ TocoFlags* toco_flags, ModelFlags* model_flags,
+ string* graph_def_contents) {
+ port::CheckInitGoogleIsDone("InitGoogle is not done yet.\n");
+
+ // Ensure savedmodel_directory is not set.
+ QCHECK(!parsed_toco_flags.savedmodel_directory.specified())
+ << "Use `tensorflow/lite/python/tflite_convert` script with "
+ << "SavedModel directories.\n";
+
+ // Checks the input file permissions and reads the contents.
+ CheckFrozenModelPermissions(parsed_toco_flags.input_file);
+ CHECK(port::file::GetContents(parsed_toco_flags.input_file.value(),
+ graph_def_contents, port::file::Defaults())
+ .ok());
+}
+} // namespace
+
+tensorflow::Status Convert(const string& graph_def_contents,
+ const TocoFlags& toco_flags,
+ const ModelFlags& model_flags,
+ string* output_file_contents) {
+ std::unique_ptr<Model> model =
+ Import(toco_flags, model_flags, graph_def_contents);
+ Transform(toco_flags, model.get());
+ return Export(toco_flags, *model, toco_flags.allow_custom_ops(),
+ output_file_contents);
+}
+
+tensorflow::Status Convert(const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags) {
+ ModelFlags model_flags;
+ ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags);
+
+ TocoFlags toco_flags;
+ ReadTocoFlagsFromCommandLineFlags(parsed_toco_flags, &toco_flags);
+
+ string graph_def_contents;
+ ReadInputData(parsed_toco_flags, parsed_model_flags, &toco_flags,
+ &model_flags, &graph_def_contents);
+ CheckOutputFilePermissions(parsed_toco_flags.output_file);
+
+ string output_file_contents;
+ TF_RETURN_IF_ERROR(Convert(graph_def_contents, toco_flags, model_flags,
+ &output_file_contents));
+
+ TF_RETURN_IF_ERROR(
+ port::file::SetContents(parsed_toco_flags.output_file.value(),
+ output_file_contents, port::file::Defaults()));
+ return tensorflow::Status();
+}
+
+} // namespace toco
diff --git a/tensorflow/lite/toco/toco_convert.h b/tensorflow/lite/toco/toco_convert.h
new file mode 100644
index 0000000..ebbd336
--- /dev/null
+++ b/tensorflow/lite/toco/toco_convert.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_TOCO_TOCO_CONVERT_H_
+#define TENSORFLOW_LITE_TOCO_TOCO_CONVERT_H_
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/lite/toco/args.h"
+#include "tensorflow/lite/toco/model_flags.pb.h"
+#include "tensorflow/lite/toco/toco_flags.pb.h"
+
+namespace toco {
+
+tensorflow::Status Convert(const string& graph_def_contents,
+ const TocoFlags& toco_flags,
+ const ModelFlags& model_flags,
+ string* output_file_contents);
+
+tensorflow::Status Convert(const ParsedTocoFlags& parsed_toco_flags,
+ const ParsedModelFlags& parsed_model_flags);
+} // namespace toco
+
+#endif // TENSORFLOW_LITE_TOCO_TOCO_CONVERT_H_
diff --git a/tensorflow/lite/toco/toco_convert_test.cc b/tensorflow/lite/toco/toco_convert_test.cc
new file mode 100644
index 0000000..c3c440d
--- /dev/null
+++ b/tensorflow/lite/toco/toco_convert_test.cc
@@ -0,0 +1,173 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/toco/toco_convert.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace toco {
+namespace {
+
+TEST(TocoTest, MissingInputFile) {
+ ParsedTocoFlags toco_flags;
+ ParsedModelFlags model_flags;
+ EXPECT_DEATH(Convert(toco_flags, model_flags).ok(),
+ "Missing required flag --input_file");
+}
+
+TEST(TocoTest, BadInputFormat) {
+ TocoFlags toco_flags;
+ ModelFlags model_flags;
+
+ string input;
+ string output;
+
+ EXPECT_DEATH(Convert(input, toco_flags, model_flags, &output).ok(),
+ "Unhandled input_format='FILE_FORMAT_UNKNOWN'");
+}
+
+TEST(TocoTest, MissingOuputArrays) {
+ TocoFlags toco_flags;
+ ModelFlags model_flags;
+
+ toco_flags.set_input_format(TENSORFLOW_GRAPHDEF);
+ string input;
+ string output;
+
+ EXPECT_DEATH(Convert(input, toco_flags, model_flags, &output).ok(),
+ "This model does not define output arrays, so a --output_arrays "
+ "flag must be given on the command-line");
+}
+
+TEST(TocoTest, BadOutputArray) {
+ TocoFlags toco_flags;
+ ModelFlags model_flags;
+
+ toco_flags.set_input_format(TENSORFLOW_GRAPHDEF);
+ model_flags.add_output_arrays("output1");
+ string input;
+ string output;
+
+ EXPECT_DEATH(Convert(input, toco_flags, model_flags, &output).ok(),
+ "Specified output array .output1. is not produced by any op "
+ "in this graph. Is it a typo. To silence this message, pass "
+ "this flag: allow_nonexistent_arrays");
+}
+
+TEST(TocoTest, BadOutputFormat) {
+ TocoFlags toco_flags;
+ ModelFlags model_flags;
+
+ toco_flags.set_input_format(TENSORFLOW_GRAPHDEF);
+ model_flags.add_output_arrays("output1");
+ string input = R"GraphDef(
+ node {
+ name: "output1"
+ input: "input1"
+ input: "input2"
+ op: "Sub"
+ attr { key: "T" value { type: DT_FLOAT } }
+ }
+ )GraphDef";
+
+ string output;
+
+ EXPECT_DEATH(Convert(input, toco_flags, model_flags, &output).ok(),
+ "Unhandled output_format='FILE_FORMAT_UNKNOWN'");
+}
+
+TEST(TocoTest, SimpleFloatModel) {
+ TocoFlags toco_flags;
+ ModelFlags model_flags;
+
+ toco_flags.set_input_format(TENSORFLOW_GRAPHDEF);
+ toco_flags.set_output_format(TENSORFLOW_GRAPHDEF);
+
+ // Inputs are automatically selected (but that might not be a good idea).
+ model_flags.add_output_arrays("output1");
+ string input = R"GraphDef(
+ node {
+ name: "input1"
+ op: "Placeholder"
+ attr { key: "dtype" value { type: DT_INT64 } }
+ }
+ node {
+ name: "input2"
+ op: "Placeholder"
+ attr { key: "dtype" value { type: DT_INT64 } }
+ }
+ node {
+ name: "output1"
+ input: "input1"
+ input: "input2"
+ op: "Sub"
+ attr { key: "T" value { type: DT_FLOAT } }
+ }
+ )GraphDef";
+
+ string output;
+ EXPECT_TRUE(Convert(input, toco_flags, model_flags, &output).ok());
+ EXPECT_TRUE(!output.empty());
+}
+
+TEST(TocoTest, TransientStringTensors) {
+ TocoFlags toco_flags;
+ ModelFlags model_flags;
+
+ toco_flags.set_input_format(TENSORFLOW_GRAPHDEF);
+
+ // We need to do a couple of things to trigger the transient array
+ // initialization code: output format must support memory planning, and the
+ // input array must have a shape.
+ toco_flags.set_output_format(TFLITE);
+
+ model_flags.add_output_arrays("output1");
+ string input = R"GraphDef(
+ node {
+ name: "input1"
+ op: "Placeholder"
+ attr { key: "dtype" value { type: DT_STRING } }
+ attr { key: "shape" value { shape { dim { size:1 }}}}
+ }
+ node {
+ name: "indices1"
+ op: "Placeholder"
+ attr { key: "dtype" value { type: DT_INT64 } }
+ }
+ node {
+ name: "intermediate1"
+ op: "Gather"
+ input: "input1"
+ input: "indices1"
+ attr { key: "Tparams" value { type: DT_STRING } }
+ attr { key: "Tindices" value { type: DT_INT64 } }
+ }
+ node {
+ name: "output1"
+ op: "Gather"
+ input: "intermediate1"
+ input: "indices2"
+ attr { key: "Tparams" value { type: DT_STRING } }
+ attr { key: "Tindices" value { type: DT_INT64 } }
+ }
+ )GraphDef";
+
+ string output;
+
+ EXPECT_TRUE(Convert(input, toco_flags, model_flags, &output).ok());
+ EXPECT_TRUE(!output.empty());
+}
+
+} // namespace
+} // namespace toco
diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc
index 5f96e83..d8b111d 100644
--- a/tensorflow/lite/toco/toco_tooling.cc
+++ b/tensorflow/lite/toco/toco_tooling.cc
@@ -210,7 +210,8 @@
CheckInvariants(*model);
break;
default:
- LOG(FATAL) << "Unhandled input_format";
+ LOG(FATAL) << "Unhandled input_format='"
+ << FileFormat_Name(toco_flags.input_format()) << "'";
}
LogDump(kLogLevelModelChanged, "AT IMPORT", *model);
@@ -424,7 +425,8 @@
DumpGraphviz(model, output_file_contents);
break;
default:
- LOG(FATAL) << "Unhandled output_format";
+ LOG(FATAL) << "Unhandled output_format='"
+ << FileFormat_Name(toco_flags.output_format()) << "'";
}
return tensorflow::Status();
}
diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc
index cff3877..611add9 100644
--- a/tensorflow/lite/toco/tooling_util.cc
+++ b/tensorflow/lite/toco/tooling_util.cc
@@ -1035,10 +1035,10 @@
if (colon_pos != string::npos) {
CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"),
string::npos)
- << "Array name must only have digits after colon";
+ << "Array '" << name << "' has non-digit characters after colon.";
}
- CHECK_GT(colon_pos, 0)
- << "First character of array name must not be a colon.";
+ CHECK_GT(colon_pos, 0) << "Array '" << name
+ << "' must not start with a colon.";
}
}
@@ -1770,6 +1770,14 @@
if (!array->has_shape()) {
return false;
}
+
+ // The size of string tensors is rarely known ahead of time, so all transient
+ // tensors of this type will need to be dynamically allocated.
+ if (array->final_data_type == ArrayDataType::kString ||
+ array->data_type == ArrayDataType::kString) {
+ return false;
+ }
+
return true;
}
diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.cc b/tensorflow/lite/tools/benchmark/benchmark_model.cc
index 05148ae..e9b485e 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_model.cc
@@ -51,11 +51,13 @@
BenchmarkParams BenchmarkModel::DefaultParams() {
BenchmarkParams params;
params.AddParam("num_runs", BenchmarkParam::Create<int32_t>(50));
+ params.AddParam("min_secs", BenchmarkParam::Create<float>(1.0f));
params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
params.AddParam("output_prefix", BenchmarkParam::Create<std::string>(""));
params.AddParam("warmup_runs", BenchmarkParam::Create<int32_t>(1));
+ params.AddParam("warmup_min_secs", BenchmarkParam::Create<float>(0.5f));
return params;
}
@@ -73,19 +75,34 @@
std::vector<Flag> BenchmarkModel::GetFlags() {
return {
- CreateFlag<int32_t>("num_runs", ¶ms_, "number of runs"),
+ CreateFlag<int32_t>("num_runs", ¶ms_,
+ "minimum number of runs, see also min_secs"),
+ CreateFlag<float>(
+ "min_secs", ¶ms_,
+ "minimum number of seconds to rerun for, potentially making the "
+ "actual number of runs to be greater than num_runs"),
CreateFlag<float>("run_delay", ¶ms_, "delay between runs in seconds"),
CreateFlag<int32_t>("num_threads", ¶ms_, "number of threads"),
CreateFlag<std::string>("benchmark_name", ¶ms_, "benchmark name"),
CreateFlag<std::string>("output_prefix", ¶ms_,
"benchmark output prefix"),
- CreateFlag<int32_t>("warmup_runs", ¶ms_,
- "how many runs to initialize model"),
+ CreateFlag<int32_t>(
+ "warmup_runs", ¶ms_,
+ "minimum number of runs performed on initialization, to "
+ "allow performance characteristics to settle, see also "
+ "warmup_min_secs"),
+ CreateFlag<float>(
+ "warmup_min_secs", ¶ms_,
+ "minimum number of seconds to rerun for, potentially making the "
+ "actual number of warm-up runs to be greater than warmup_runs"),
};
}
void BenchmarkModel::LogParams() {
- TFLITE_LOG(INFO) << "Num runs: [" << params_.Get<int32_t>("num_runs") << "]";
+ TFLITE_LOG(INFO) << "Min num runs: [" << params_.Get<int32_t>("num_runs")
+ << "]";
+ TFLITE_LOG(INFO) << "Min runs duration (seconds): ["
+ << params_.Get<float>("min_secs") << "]";
TFLITE_LOG(INFO) << "Inter-run delay (seconds): ["
<< params_.Get<float>("run_delay") << "]";
TFLITE_LOG(INFO) << "Num threads: [" << params_.Get<int32_t>("num_threads")
@@ -94,16 +111,24 @@
<< params_.Get<std::string>("benchmark_name") << "]";
TFLITE_LOG(INFO) << "Output prefix: ["
<< params_.Get<std::string>("output_prefix") << "]";
- TFLITE_LOG(INFO) << "Warmup runs: [" << params_.Get<int32_t>("warmup_runs")
- << "]";
+ TFLITE_LOG(INFO) << "Min warmup runs: ["
+ << params_.Get<int32_t>("warmup_runs") << "]";
+ TFLITE_LOG(INFO) << "Min warmup runs duration (seconds): ["
+ << params_.Get<float>("warmup_min_secs") << "]";
}
void BenchmarkModel::PrepareInputsAndOutputs() {}
-Stat<int64_t> BenchmarkModel::Run(int num_times, RunType run_type) {
+Stat<int64_t> BenchmarkModel::Run(int min_num_times, float min_secs,
+ RunType run_type) {
Stat<int64_t> run_stats;
- TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations ";
- for (int run = 0; run < num_times; run++) {
+ TFLITE_LOG(INFO) << "Running benchmark for at least " << min_num_times
+ << " iterations and at least " << min_secs << " seconds";
+ int64_t min_finish_us =
+ profiling::time::NowMicros() + static_cast<int64_t>(min_secs * 1.e6f);
+ for (int run = 0;
+ run < min_num_times || profiling::time::NowMicros() < min_finish_us;
+ run++) {
PrepareInputsAndOutputs();
listeners_.OnSingleRunStart(run_type);
int64_t start_us = profiling::time::NowMicros();
@@ -145,9 +170,11 @@
uint64_t input_bytes = ComputeInputBytes();
Stat<int64_t> warmup_time_us =
- Run(params_.Get<int32_t>("warmup_runs"), WARMUP);
+ Run(params_.Get<int32_t>("warmup_runs"),
+ params_.Get<float>("warmup_min_secs"), WARMUP);
Stat<int64_t> inference_time_us =
- Run(params_.Get<int32_t>("num_runs"), REGULAR);
+ Run(params_.Get<int32_t>("num_runs"), params_.Get<float>("min_secs"),
+ REGULAR);
listeners_.OnBenchmarkEnd(
{startup_latency_us, input_bytes, warmup_time_us, inference_time_us});
}
diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.h b/tensorflow/lite/tools/benchmark/benchmark_model.h
index d8a9b05..31ee5c9 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_model.h
+++ b/tensorflow/lite/tools/benchmark/benchmark_model.h
@@ -150,7 +150,8 @@
bool ParseFlags(int argc, char** argv);
virtual std::vector<Flag> GetFlags();
virtual uint64_t ComputeInputBytes() = 0;
- virtual tensorflow::Stat<int64_t> Run(int num_times, RunType run_type);
+ virtual tensorflow::Stat<int64_t> Run(int min_num_times, float min_secs,
+ RunType run_type);
virtual void PrepareInputsAndOutputs();
virtual void RunImpl() = 0;
BenchmarkParams params_;
diff --git a/tensorflow/lite/tools/benchmark/benchmark_test.cc b/tensorflow/lite/tools/benchmark/benchmark_test.cc
index 59d23d9..8191fbc 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_test.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_test.cc
@@ -33,6 +33,7 @@
BenchmarkParams CreateParams() {
BenchmarkParams params;
params.AddParam("num_runs", BenchmarkParam::Create<int32_t>(2));
+ params.AddParam("min_secs", BenchmarkParam::Create<float>(1.0f));
params.AddParam("run_delay", BenchmarkParam::Create<float>(-1.0f));
params.AddParam("num_threads", BenchmarkParam::Create<int32_t>(1));
params.AddParam("benchmark_name", BenchmarkParam::Create<std::string>(""));
@@ -42,6 +43,7 @@
params.AddParam("input_layer", BenchmarkParam::Create<std::string>(""));
params.AddParam("input_layer_shape", BenchmarkParam::Create<std::string>(""));
params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(false));
+ params.AddParam("warmup_min_secs", BenchmarkParam::Create<float>(0.5f));
return params;
}
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index 777d9dd..7768b75 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -181,7 +181,9 @@
return true;
}
-BenchmarkParams GetDefaultParams() {
+} // namespace
+
+BenchmarkParams BenchmarkTfLiteModel::DefaultParams() {
BenchmarkParams default_params = BenchmarkModel::DefaultParams();
default_params.AddParam("graph", BenchmarkParam::Create<std::string>(""));
default_params.AddParam("input_layer",
@@ -192,10 +194,8 @@
return default_params;
}
-} // namespace
-
BenchmarkTfLiteModel::BenchmarkTfLiteModel()
- : BenchmarkTfLiteModel(GetDefaultParams()) {}
+ : BenchmarkTfLiteModel(DefaultParams()) {}
BenchmarkTfLiteModel::BenchmarkTfLiteModel(BenchmarkParams params)
: BenchmarkModel(std::move(params)) {
@@ -319,6 +319,7 @@
bool use_nnapi = params_.Get<bool>("use_nnapi");
interpreter->UseNNAPI(use_nnapi);
+ ApplyDelegates();
auto interpreter_inputs = interpreter->inputs();
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
index 401ab54..83599e6 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
@@ -77,11 +77,16 @@
};
protected:
+ static BenchmarkParams DefaultParams();
void PrepareInputsAndOutputs() override;
- private:
+ // Allows installation of custom delegates during initialization
+ virtual void ApplyDelegates() {}
+
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
+
+ private:
std::vector<InputLayerInfo> inputs;
ProfilingListener profiling_listener_;
GemmlowpProfilingListener gemmlowp_profiling_listener_;
diff --git a/tensorflow/lite/tools/make/targets/ios_makefile.inc b/tensorflow/lite/tools/make/targets/ios_makefile.inc
index 7f36b8e..ae9276f 100644
--- a/tensorflow/lite/tools/make/targets/ios_makefile.inc
+++ b/tensorflow/lite/tools/make/targets/ios_makefile.inc
@@ -22,7 +22,7 @@
TARGET_ARCH := x86_64
CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \
-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
- -DTFLITE_USE_APPLE_ACCELERATE_FOR_CONV \
+ -DTF_LITE_USE_CBLAS \
-fembed-bitcode \
-Wno-c++11-narrowing \
-mno-thumb \
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 1010678..19d2af4 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -861,6 +861,7 @@
":platform",
":registry",
":tensor_shape",
+ ":tf2",
":traceable_stack",
":util",
":versions",
@@ -984,6 +985,7 @@
srcs_version = "PY2AND3",
deps = [
":dtypes",
+ ":tf2",
":util",
"//tensorflow/core:protos_all_py",
],
@@ -1081,10 +1083,12 @@
srcs_version = "PY2AND3",
deps = [
":client",
+ ":cond_v2",
":framework_test_lib",
":gradient_checker",
":platform_test",
":util",
+ ":while_v2",
],
)
@@ -2083,7 +2087,6 @@
srcs = ["ops/control_flow_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "tensor_shape",
":array_ops",
":array_ops_gen",
":constant_op",
@@ -2098,6 +2101,7 @@
":resource_variable_ops_gen",
":sparse_tensor",
":tensor_array_ops",
+ ":tensor_shape",
":tf2",
":tf_should_use",
":util",
@@ -2844,6 +2848,33 @@
)
py_library(
+ name = "sort_ops",
+ srcs = ["ops/sort_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":framework",
+ ":math_ops",
+ ":nn_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "sort_ops_test",
+ srcs = ["ops/sort_ops_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":client_testlib",
+ ":framework",
+ ":random_ops",
+ ":sort_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "spectral_ops_test_util",
srcs = ["ops/spectral_ops_test_util.py"],
srcs_version = "PY2AND3",
@@ -2957,6 +2988,7 @@
":random_ops",
":script_ops",
":session_ops",
+ ":sort_ops",
":sparse_grad",
":sparse_ops",
":special_math_ops",
@@ -3564,17 +3596,6 @@
)
py_library(
- name = "device_util",
- srcs = ["training/device_util.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":device",
- ":framework_ops",
- "//tensorflow/python/eager:context",
- ],
-)
-
-py_library(
name = "distribute",
srcs = [
"training/distribute.py",
@@ -3582,35 +3603,7 @@
],
srcs_version = "PY2AND3",
deps = [
- ":array_ops",
- ":constant_op",
- ":control_flow_ops",
- ":device_util",
- ":dtypes",
- ":framework_ops",
- ":platform",
- ":resource_variable_ops",
- ":state_ops",
- ":util",
- ":variable_scope",
- "//tensorflow/python/data",
- "//tensorflow/python/distribute:reduce_util",
- "//tensorflow/python/ops/losses",
- "//tensorflow/tools/docs:doc_controls",
- ],
-)
-
-py_test(
- name = "distribute_test",
- size = "small",
- srcs = ["training/distribute_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":client_testlib",
- ":constant_op",
- ":distribute",
- ":dtypes",
- ":variable_scope",
+ "//tensorflow/python/distribute:distribute_lib",
],
)
@@ -4599,7 +4592,6 @@
"training/basic_loops_test.py",
"training/coordinator_test.py",
"training/device_setter_test.py",
- "training/device_util_test.py",
"training/ftrl_test.py",
"training/gradient_descent_test.py",
"training/learning_rate_decay_test.py",
diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD
index ced2e47..3ac446d 100644
--- a/tensorflow/python/autograph/converters/BUILD
+++ b/tensorflow/python/autograph/converters/BUILD
@@ -63,7 +63,6 @@
name = "asserts_test",
srcs = ["asserts_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":converters",
"//tensorflow/python:client_testlib",
@@ -239,7 +238,6 @@
name = "error_handlers_test",
srcs = ["error_handlers_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":converters",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py
index eef628a..803b6a0 100644
--- a/tensorflow/python/autograph/converters/asserts_test.py
+++ b/tensorflow/python/autograph/converters/asserts_test.py
@@ -41,7 +41,7 @@
op = result.test_fn(constant_op.constant(False))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'test message'):
- sess.run(op)
+ self.evaluate(op)
if __name__ == '__main__':
diff --git a/tensorflow/python/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py
index 892f90e..9d76016 100644
--- a/tensorflow/python/autograph/converters/call_trees_test.py
+++ b/tensorflow/python/autograph/converters/call_trees_test.py
@@ -94,7 +94,7 @@
dtypes.int64) as result:
with self.cached_session() as sess:
self.assertTrue(isinstance(result.test_fn(), ops.Tensor))
- self.assertIn(sess.run(result.test_fn()), (0, 1, 2))
+ self.assertIn(self.evaluate(result.test_fn()), (0, 1, 2))
def test_uncompiled_modules(self):
diff --git a/tensorflow/python/autograph/converters/lists_test.py b/tensorflow/python/autograph/converters/lists_test.py
index 8c8135a..39843c7 100644
--- a/tensorflow/python/autograph/converters/lists_test.py
+++ b/tensorflow/python/autograph/converters/lists_test.py
@@ -123,7 +123,7 @@
with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
with self.cached_session() as sess:
- self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3])
+ self.assertAllEqual(self.evaluate(result.test_fn()), [1, 2, 3])
# TODO(mdan): Add a test with tf.stack with axis kwarg.
diff --git a/tensorflow/python/autograph/converters/side_effect_guards_test.py b/tensorflow/python/autograph/converters/side_effect_guards_test.py
index e72b5ea..f6d0f73 100644
--- a/tensorflow/python/autograph/converters/side_effect_guards_test.py
+++ b/tensorflow/python/autograph/converters/side_effect_guards_test.py
@@ -49,7 +49,7 @@
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
- sess.run(result.test_fn(v))
+ self.evaluate(result.test_fn(v))
# TODO(mdan): Add support for this use case.
# Right now the variable `a` is not conditioned on the `assign` because
# there's no way to add control dependencies to a variable object.
@@ -70,7 +70,7 @@
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
- sess.run(result.test_fn(v))
+ self.evaluate(result.test_fn(v))
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
# Right now it's 3 or 4 based on whether the read is synchronized.
self.assertEqual(3, self.evaluate(v))
@@ -110,7 +110,7 @@
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
- sess.run(result.test_fn(v))
+ self.evaluate(result.test_fn(v))
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
self.assertEqual(4, self.evaluate(v))
@@ -131,7 +131,7 @@
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
- sess.run(result.test_fn(v))
+ self.evaluate(result.test_fn(v))
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
self.assertEqual(3, self.evaluate(v))
@@ -154,7 +154,7 @@
with self.cached_session() as sess:
v = variable_scope.get_variable('test', initializer=2)
self.evaluate(v.initializer)
- sess.run(result.test_fn(v))
+ self.evaluate(result.test_fn(v))
# TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
self.assertEqual(4, self.evaluate(v))
diff --git a/tensorflow/python/autograph/core/errors_test.py b/tensorflow/python/autograph/core/errors_test.py
index aa6c293..00c8a72 100644
--- a/tensorflow/python/autograph/core/errors_test.py
+++ b/tensorflow/python/autograph/core/errors_test.py
@@ -55,7 +55,7 @@
with self.assertRaises(errors.TfRuntimeError) as cm:
with errors.improved_errors(zero_div_caller):
with self.cached_session() as sess:
- sess.run(ops)
+ self.evaluate(ops)
for frame in cm.exception.custom_traceback:
_, _, function_name, _ = frame
@@ -70,7 +70,7 @@
with self.assertRaises(errors.TfRuntimeError) as cm:
with errors.improved_errors(zero_div_caller):
with self.cached_session() as sess:
- sess.run(ops)
+ self.evaluate(ops)
all_function_names = set()
for frame in cm.exception.custom_traceback:
@@ -87,7 +87,7 @@
with self.assertRaises(tf_errors.InvalidArgumentError):
with errors.improved_errors(zero_div_caller):
with self.cached_session() as sess:
- sess.run(ops)
+ self.evaluate(ops)
def test_improved_errors_validation(self):
with self.assertRaisesRegexp(
diff --git a/tensorflow/python/autograph/impl/BUILD b/tensorflow/python/autograph/impl/BUILD
index 2f9037c..201a888 100644
--- a/tensorflow/python/autograph/impl/BUILD
+++ b/tensorflow/python/autograph/impl/BUILD
@@ -41,7 +41,6 @@
name = "api_test",
srcs = ["api_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":impl",
"//tensorflow/python:client_testlib",
@@ -54,7 +53,6 @@
name = "conversion_test",
srcs = ["conversion_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":impl",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index 48a9307..3dfc12e 100644
--- a/tensorflow/python/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -281,11 +281,10 @@
node, source = parser.parse_entity(f)
node = node.body[0]
- # In general, the output of inspect.getsource is inexact because it uses crude
- # regex matching methods to search the source file. This is particularly
- # problematic for lambda functions, where the entire containing lines are
- # returned. Certain distributions of CPython may also return the enclosing
- # function for local functions.
+ # In general, the output of inspect.getsource is inexact because it uses
+ # regex matching to adjust the exact location around the line number that
+ # CPython records. This is particularly problematic for lambda functions,
+ # where the entire containing lines are returned.
nodes = ast_util.find_matching_definitions(node, f)
if len(nodes) != 1:
if f.__name__ == '<lambda>':
@@ -295,17 +294,11 @@
' matching signature. To avoid ambiguity, define each lambda'
' in a separate expression.'.format(f, source))
else:
- # The inspect.getsource bug is currently known to occur in the Windows
- # integration tests which run Python 3.6.
- # TODO(mdan): Find out eaxctly which distribution of Python is that.
raise ValueError(
'Unable to identify source code of function {}. The source code'
' reported by Python did not include exactly one matching signature:'
- '\n{}\nTo avoid ambiguity, use a unique name for each'
- ' function.\nNote that some distributions of Python may report source'
- ' code incorrectly. It may be possible to avoid that bug by'
- ' organizing the code into smaller units (smaller files, functions or'
- ' classes), or by turning AutoGraph off.'.format(f, source))
+ '\n{}\n. This is an extremely rare occurrence. Please report it to'
+ ' the TensorFlow team.'.format(f, source))
node, = nodes
# TODO(znado): Place inside standard_analysis.
diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py
index 6eedd69..1a35efe 100644
--- a/tensorflow/python/autograph/operators/control_flow.py
+++ b/tensorflow/python/autograph/operators/control_flow.py
@@ -61,7 +61,7 @@
"""
if tensor_util.is_tensor(iter_):
return _known_len_for_stmt(iter_, extra_test, body, init_state)
- elif isinstance(iter_, dataset_ops.Dataset):
+ elif isinstance(iter_, dataset_ops.DatasetV2):
return _dataset_for_stmt(iter_, extra_test, body, init_state)
else:
return _py_for_stmt(iter_, extra_test, body, init_state)
diff --git a/tensorflow/python/autograph/operators/data_structures_test.py b/tensorflow/python/autograph/operators/data_structures_test.py
index dc50edb..0433e3f 100644
--- a/tensorflow/python/autograph/operators/data_structures_test.py
+++ b/tensorflow/python/autograph/operators/data_structures_test.py
@@ -109,8 +109,8 @@
l1 = data_structures.list_append(l, 1)
l2 = data_structures.list_append(l1, 2)
with self.cached_session() as sess:
- self.assertAllEqual(sess.run(l1.stack()), [1])
- self.assertAllEqual(sess.run(l2.stack()), [1, 2])
+ self.assertAllEqual(self.evaluate(l1.stack()), [1])
+ self.assertAllEqual(self.evaluate(l2.stack()), [1, 2])
def test_append_python(self):
l = []
@@ -152,7 +152,7 @@
with self.cached_session() as sess:
t = data_structures.list_stack(l, opts)
- self.assertAllEqual(sess.run(t), self.evaluate(initial_list))
+ self.assertAllEqual(self.evaluate(t), self.evaluate(initial_list))
def test_stack_tensor_list_empty(self):
l = list_ops.empty_tensor_list(
diff --git a/tensorflow/python/autograph/operators/exceptions_test.py b/tensorflow/python/autograph/operators/exceptions_test.py
index 186535d..24d3f1b 100644
--- a/tensorflow/python/autograph/operators/exceptions_test.py
+++ b/tensorflow/python/autograph/operators/exceptions_test.py
@@ -30,7 +30,7 @@
with self.cached_session() as sess:
t = exceptions.assert_stmt(
constant_op.constant(True), lambda: constant_op.constant('ignored'))
- sess.run(t)
+ self.evaluate(t)
def test_assert_tf_triggered(self):
with self.cached_session() as sess:
@@ -40,7 +40,7 @@
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'test message'):
- sess.run(t)
+ self.evaluate(t)
def test_assert_tf_multiple_printed_values(self):
two_tensors = [
@@ -53,7 +53,7 @@
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'test message.*another message'):
- sess.run(t)
+ self.evaluate(t)
def test_assert_python_untriggered(self):
side_effect_trace = []
diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD
index ddadc6b..ba8ec27 100644
--- a/tensorflow/python/autograph/pyct/BUILD
+++ b/tensorflow/python/autograph/pyct/BUILD
@@ -80,7 +80,6 @@
name = "compiler_test",
srcs = ["compiler_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":pyct",
"//tensorflow/python:client_testlib",
@@ -154,7 +153,6 @@
name = "transformer_test",
srcs = ["transformer_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":pyct",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/autograph/pyct/static_analysis/BUILD b/tensorflow/python/autograph/pyct/static_analysis/BUILD
index 4a4ccdc..5e260c5 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/BUILD
+++ b/tensorflow/python/autograph/pyct/static_analysis/BUILD
@@ -38,7 +38,6 @@
name = "activity_test",
srcs = ["activity_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":static_analysis",
"//tensorflow/python:client_testlib",
@@ -51,7 +50,6 @@
name = "live_values_test",
srcs = ["live_values_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"],
deps = [
":static_analysis",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/autograph/utils/tensor_list_test.py b/tensorflow/python/autograph/utils/tensor_list_test.py
index a5bbd97..c655f77 100644
--- a/tensorflow/python/autograph/utils/tensor_list_test.py
+++ b/tensorflow/python/autograph/utils/tensor_list_test.py
@@ -92,7 +92,7 @@
a2 = l.pop()
c4 = l.count()
with Session() as sess:
- c1, c2, c3, c4, a, a2 = sess.run([c1, c2, c3, c4, a, a2])
+ c1, c2, c3, c4, a, a2 = self.evaluate([c1, c2, c3, c4, a, a2])
self.assertEqual(c1, 1)
self.assertEqual(c2, 2)
self.assertEqual(c3, 1)
@@ -108,7 +108,7 @@
l[0] = b
l1 = l[0]
with self.cached_session() as sess:
- l0, l1, a, b = sess.run([l0, l1, a, b])
+ l0, l1, a, b = self.evaluate([l0, l1, a, b])
self.assertEqual(l0, a)
self.assertEqual(l1, b)
diff --git a/tensorflow/python/autograph/utils/type_check.py b/tensorflow/python/autograph/utils/type_check.py
index 8748abc..ccef7de 100644
--- a/tensorflow/python/autograph/utils/type_check.py
+++ b/tensorflow/python/autograph/utils/type_check.py
@@ -30,4 +30,4 @@
Returns:
True if any *args are TensorFlow types, False if none are.
"""
- return any([tensor_util.is_tensor(a) for a in args])
+ return any(tensor_util.is_tensor(a) for a in args)
diff --git a/tensorflow/python/client/device_lib.i b/tensorflow/python/client/device_lib.i
index 944e855..3e57915 100644
--- a/tensorflow/python/client/device_lib.i
+++ b/tensorflow/python/client/device_lib.i
@@ -48,17 +48,14 @@
std::vector<string> output;
SessionOptions options;
options.config = config;
- std::vector<Device*> devices;
+ std::vector<std::unique_ptr<Device>> devices;
Status status = DeviceFactory::AddDevices(
options, "" /* name_prefix */, &devices);
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
}
- std::vector<std::unique_ptr<Device>> device_holder(devices.begin(),
- devices.end());
-
- for (const Device* device : devices) {
+ for (const std::unique_ptr<Device>& device : devices) {
const DeviceAttributes& attr = device->attributes();
string attr_serialized;
if (!attr.SerializeToString(&attr_serialized)) {
diff --git a/tensorflow/python/client/session_partial_run_test.py b/tensorflow/python/client/session_partial_run_test.py
index 92ca47e..a9bd5ab 100644
--- a/tensorflow/python/client/session_partial_run_test.py
+++ b/tensorflow/python/client/session_partial_run_test.py
@@ -117,7 +117,7 @@
a = constant_op.constant(2.0, dtypes.float32)
b = a * 2
c = b * 3
- r1 = sess.run([b, c])
+ r1 = self.evaluate([b, c])
h = sess.partial_run_setup([b, c], [])
r2 = sess.partial_run(h, [b, c])
self.assertEqual(r1, r2)
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 1b8114d..1f43793 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 11, 20)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 11, 28)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/benchmarks/BUILD b/tensorflow/python/data/benchmarks/BUILD
index fd723e0..5b0500e 100644
--- a/tensorflow/python/data/benchmarks/BUILD
+++ b/tensorflow/python/data/benchmarks/BUILD
@@ -7,6 +7,61 @@
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
+ name = "batch_benchmark",
+ srcs = ["batch_benchmark.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "filter_benchmark",
+ srcs = ["filter_benchmark.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "from_tensor_slices_benchmark",
+ srcs = ["from_tensor_slices_benchmark.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "map_benchmark",
+ srcs = ["map_benchmark.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "range_benchmark",
srcs = ["range_benchmark.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/data/benchmarks/batch_benchmark.py b/tensorflow/python/data/benchmarks/batch_benchmark.py
new file mode 100644
index 0000000..b61ac86
--- /dev/null
+++ b/tensorflow/python/data/benchmarks/batch_benchmark.py
@@ -0,0 +1,85 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmarks for `tf.data.Dataset.batch()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+# TODO(b/119837791): Add eager benchmarks.
+class BatchBenchmark(test.Benchmark):
+ """Benchmarks for `tf.data.Dataset.batch()`."""
+
+ def benchmarkBatchSparse(self):
+ non_zeros_per_row_values = [0, 1, 5, 10, 100]
+ batch_size_values = [1, 32, 64, 128, 1024]
+
+ sparse_placeholder = array_ops.sparse_placeholder(dtype=dtypes.int64)
+ batch_size_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[])
+
+ dataset = dataset_ops.Dataset.from_tensors(sparse_placeholder).repeat(
+ ).batch(batch_size_placeholder)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ for non_zeros_per_row in non_zeros_per_row_values:
+
+ sparse_value = sparse_tensor.SparseTensorValue(
+ indices=np.arange(non_zeros_per_row, dtype=np.int64)[:, np.newaxis],
+ values=np.arange(non_zeros_per_row, dtype=np.int64),
+ dense_shape=[1000])
+
+ for batch_size in batch_size_values:
+
+ with session.Session() as sess:
+ sess.run(iterator.initializer, feed_dict={
+ sparse_placeholder: sparse_value,
+ batch_size_placeholder: batch_size})
+ # Run five steps to warm up the session caches before taking the
+ # first measurement.
+ for _ in range(5):
+ sess.run(next_element.indices.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.indices.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100.0
+
+ print("Batch sparse dataset non-zeros per row: %d batch_size: %d "
+ "wall time: %f"
+ % (non_zeros_per_row, batch_size, median_wall_time))
+ self.report_benchmark(
+ iters=10000, wall_time=median_wall_time,
+ name="batch_sparse_dataset_nnz_%d_batch_size_%d" % (
+ non_zeros_per_row, batch_size))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/benchmarks/filter_benchmark.py b/tensorflow/python/data/benchmarks/filter_benchmark.py
new file mode 100644
index 0000000..b9acdc7
--- /dev/null
+++ b/tensorflow/python/data/benchmarks/filter_benchmark.py
@@ -0,0 +1,69 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmarks for `tf.data.Dataset.filter()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+# TODO(b/119837791): Add eager benchmarks.
+class FilterBenchmark(test.Benchmark):
+ """Benchmarks for `tf.data.Dataset.filter()`."""
+
+ def _benchmark(self, predicate, name):
+ with ops.Graph().as_default():
+ dataset = (
+ dataset_ops.Dataset.from_tensors(True).repeat(None).filter(predicate))
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(5):
+ sess.run(next_element.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ print("Filter dataset using %s. Median wall time: %f" %
+ (name, median_wall_time))
+ self.report_benchmark(
+ iters=100,
+ wall_time=median_wall_time,
+ name=name)
+
+ def benchmarkSimpleFunction(self):
+ self._benchmark(array_ops.identity, "simple_function")
+
+ def benchmarkReturnComponentOptimization(self):
+ self._benchmark(lambda x: x, "return_component")
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/benchmarks/from_tensor_slices_benchmark.py b/tensorflow/python/data/benchmarks/from_tensor_slices_benchmark.py
new file mode 100644
index 0000000..74a2d27
--- /dev/null
+++ b/tensorflow/python/data/benchmarks/from_tensor_slices_benchmark.py
@@ -0,0 +1,188 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmarks for `tf.data.Dataset.from_tensor_slices()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+# TODO(b/119837791): Add eager benchmarks.
+class FromTensorSlicesBenchmark(test.Benchmark):
+ """Benchmarks for `tf.data.Dataset.from_tensor_slices()`."""
+
+ def benchmarkSliceRepeatBatch(self):
+ input_size = 10000
+ batch_size = 100
+ num_epochs = 100
+
+ input_data = np.random.randn(input_size)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(input_data)
+ .repeat(num_epochs + 1).batch(batch_size))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ sess.run(iterator.initializer)
+ # Run one whole epoch to burn in the computation.
+ for _ in range(input_size // batch_size):
+ sess.run(next_element)
+ deltas = []
+ try:
+ while True:
+ start = time.time()
+ sess.run(next_element)
+ deltas.append(time.time() - start)
+ except errors.OutOfRangeError:
+ pass
+
+ median_wall_time = np.median(deltas)
+ print("Slice/repeat/batch with sess.run() input size: %d batch size: %d "
+ "Median wall time per element: %f" % (input_size, batch_size,
+ median_wall_time))
+ self.report_benchmark(
+ iters=len(deltas),
+ wall_time=median_wall_time,
+ name="slice_repeat_batch_input_%d_batch_%d" % (input_size, batch_size))
+
+ def benchmarkSliceRepeatBatchCallable(self):
+ input_size = 10000
+ batch_size = 100
+ num_epochs = 100
+
+ input_data = np.random.randn(input_size)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(input_data)
+ .repeat(num_epochs + 1).batch(batch_size))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ sess.run(iterator.initializer)
+ get_next_element = sess.make_callable(next_element)
+ # Run one whole epoch to burn in the computation.
+ for _ in range(input_size // batch_size):
+ get_next_element()
+ deltas = []
+ try:
+ while True:
+ start = time.time()
+ get_next_element()
+ deltas.append(time.time() - start)
+ except errors.OutOfRangeError:
+ pass
+
+ median_wall_time = np.median(deltas)
+ print(
+ "Slice/repeat/batch with callable input size: %d batch size: %d Median"
+ " wall time per element: %f" % (input_size, batch_size,
+ median_wall_time))
+ self.report_benchmark(
+ iters=len(deltas),
+ wall_time=median_wall_time,
+ name="slice_repeat_batch_callable_input_%d_batch_%d" %
+ (input_size, batch_size))
+
+ def benchmarkReshapeSliceRepeatCallable(self):
+ input_size = 10000
+ batch_size = 100
+ num_epochs = 100
+
+ input_data = np.random.randn(input_size)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(input_data.reshape(100, 100))
+ .repeat(num_epochs + 1))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ sess.run(iterator.initializer)
+ get_next_element = sess.make_callable(next_element)
+ # Run one whole epoch to burn in the computation.
+ for _ in range(input_size // batch_size):
+ get_next_element()
+ deltas = []
+ try:
+ while True:
+ start = time.time()
+ get_next_element()
+ deltas.append(time.time() - start)
+ except errors.OutOfRangeError:
+ pass
+
+ median_wall_time = np.median(deltas)
+ print("Reshape/slice/repeat with callable input size: %d batch size: %d "
+ "Median wall time per element: %f" % (input_size, batch_size,
+ median_wall_time))
+ self.report_benchmark(
+ iters=len(deltas),
+ wall_time=median_wall_time,
+ name="reshape_slice_repeat_callable_input_%d_batch_%d" %
+ (input_size, batch_size))
+
+ def benchmarkSliceBatchCacheRepeatCallable(self):
+ input_size = 10000
+ batch_size = 100
+ num_epochs = 100
+
+ input_data = np.random.randn(input_size)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(input_data).batch(batch_size)
+ .cache().repeat(num_epochs + 1))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ sess.run(iterator.initializer)
+ get_next_element = sess.make_callable(next_element)
+ # Run one whole epoch to burn in the computation.
+ for _ in range(input_size // batch_size):
+ get_next_element()
+ deltas = []
+ try:
+ while True:
+ start = time.time()
+ get_next_element()
+ deltas.append(time.time() - start)
+ except errors.OutOfRangeError:
+ pass
+
+ median_wall_time = np.median(deltas)
+ print(
+ "Slice/batch/cache/repeat with callable input size: %d batch size: %d "
+ "Median wall time per element: %f"
+ % (input_size, batch_size, median_wall_time))
+ self.report_benchmark(
+ iters=len(deltas),
+ wall_time=median_wall_time,
+ name="slice_batch_cache_repeat_callable_input_%d_batch_%d" %
+ (input_size, batch_size))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/benchmarks/map_benchmark.py b/tensorflow/python/data/benchmarks/map_benchmark.py
new file mode 100644
index 0000000..48294ee
--- /dev/null
+++ b/tensorflow/python/data/benchmarks/map_benchmark.py
@@ -0,0 +1,135 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Bechmarks for `tf.data.Dataset.map()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+
+
+# TODO(b/119837791): Add eager benchmarks.
+class MapBenchmark(test.Benchmark):
+ """Bechmarks for `tf.data.Dataset.map()`."""
+
+ def benchmarkChainOfMaps(self):
+ chain_lengths = [0, 1, 2, 5, 10, 20, 50]
+ for chain_length in chain_lengths:
+ for mode in ["general", "single-threaded", "short-circuit"]:
+ if mode == "general":
+ map_fn = lambda x: x + 1
+ use_inter_op_parallelism = True
+ print_label = ""
+ benchmark_label = ""
+ if mode == "single-threaded":
+ map_fn = lambda x: x + 1
+ use_inter_op_parallelism = False
+ print_label = " (single threaded mode)"
+ benchmark_label = "_single_threaded"
+ if mode == "short-circuit":
+ map_fn = lambda x: x
+ use_inter_op_parallelism = True # should not have any significance
+ print_label = " (short circuit mode)"
+ benchmark_label = "_short_circuit"
+
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset_ops.MapDataset(
+ dataset,
+ map_fn,
+ use_inter_op_parallelism=use_inter_op_parallelism)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(5):
+ sess.run(next_element.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ print("Map dataset chain length%s: %d Median wall time: %f" %
+ (print_label, chain_length, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="map_dataset_chain_length_%d%s" % (chain_length,
+ benchmark_label))
+
+ def benchmarkMapFanOut(self):
+ fan_outs = [1, 2, 5, 10, 20, 50, 100]
+ for fan_out in fan_outs:
+ for mode in ["general", "single-threaded", "short-circuit"]:
+ if mode == "general":
+ map_fn = lambda *xs: [x + 1 for x in xs]
+ use_inter_op_parallelism = True
+ print_label = ""
+ benchmark_label = ""
+ if mode == "single-threaded":
+ map_fn = lambda *xs: [x + 1 for x in xs]
+ use_inter_op_parallelism = False
+ print_label = " (single threaded mode)"
+ benchmark_label = "_single_threaded"
+ if mode == "short-circuit":
+ map_fn = lambda *xs: xs
+ use_inter_op_parallelism = True # should not have any significance
+ print_label = " (short circuit mode)"
+ benchmark_label = "_short_circuit"
+
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(
+ tuple(0 for _ in range(fan_out))).repeat(None)
+ dataset = dataset_ops.MapDataset(
+ dataset,
+ map_fn,
+ use_inter_op_parallelism=use_inter_op_parallelism)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(5):
+ sess.run(next_element[0].op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element[0].op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ print("Map dataset fan out%s: %d Median wall time: %f" %
+ (print_label, fan_out, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="map_dataset_fan_out_%d%s" % (fan_out, benchmark_label))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/__init__.py b/tensorflow/python/data/experimental/__init__.py
index 126c2be..12aa0a0 100644
--- a/tensorflow/python/data/experimental/__init__.py
+++ b/tensorflow/python/data/experimental/__init__.py
@@ -32,6 +32,7 @@
@@StatsAggregator
@@StatsOptions
@@TFRecordWriter
+@@ThreadingOptions
@@bucket_by_sequence_length
@@choose_from_datasets
@@ -101,6 +102,7 @@
from tensorflow.python.data.experimental.ops.stats_aggregator import StatsAggregator
from tensorflow.python.data.experimental.ops.stats_ops import latency_stats
from tensorflow.python.data.experimental.ops.stats_options import StatsOptions
+from tensorflow.python.data.experimental.ops.threading_options import ThreadingOptions
from tensorflow.python.data.experimental.ops.unique import unique
from tensorflow.python.data.experimental.ops.writers import TFRecordWriter
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
diff --git a/tensorflow/python/data/experimental/benchmarks/BUILD b/tensorflow/python/data/experimental/benchmarks/BUILD
index b89fbe7..8175116 100644
--- a/tensorflow/python/data/experimental/benchmarks/BUILD
+++ b/tensorflow/python/data/experimental/benchmarks/BUILD
@@ -8,15 +8,12 @@
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
- name = "map_and_batch_benchmark",
- size = "medium",
- srcs = ["map_and_batch_benchmark.py"],
+ name = "autotune_benchmark",
+ srcs = ["autotune_benchmark.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:random_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:session",
"//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/experimental/ops:optimization",
@@ -26,17 +23,102 @@
)
py_test(
- name = "map_benchmark",
- size = "medium",
- srcs = ["map_benchmark.py"],
+ name = "csv_dataset_benchmark",
+ srcs = ["csv_dataset_benchmark.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/ops:readers",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "map_and_batch_benchmark",
+ srcs = ["map_and_batch_benchmark.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "map_vectorization_benchmark",
+ srcs = ["map_vectorization_benchmark.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "matching_files_benchmark",
+ size = "small",
+ srcs = ["matching_files_benchmark.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:matching_files",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "optimize_benchmark",
+ srcs = ["optimize_benchmark.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "unbatch_benchmark",
+ srcs = ["unbatch_benchmark.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:session",
"//tensorflow/python/data/experimental/ops:batching",
- "//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
diff --git a/tensorflow/python/data/experimental/benchmarks/autotune_benchmark.py b/tensorflow/python/data/experimental/benchmarks/autotune_benchmark.py
new file mode 100644
index 0000000..b00e918
--- /dev/null
+++ b/tensorflow/python/data/experimental/benchmarks/autotune_benchmark.py
@@ -0,0 +1,187 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmarks for autotuning performance knobs."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class AutotuneBenchmark(test.Benchmark):
+ """Benchmarks for autotuning performance knobs."""
+
+ def benchmarkMap(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(
+ math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with session.Session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(1000):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+ self.report_benchmark(
+ iters=1000, wall_time=np.median(deltas), name="map_autotune")
+
+ def benchmarkMapAndBatch(self):
+ self._benchmarkMapAndBatch(numa_aware=False)
+ self._benchmarkMapAndBatch(numa_aware=True)
+
+ def _benchmarkMapAndBatch(self, numa_aware):
+ batch_size = 16
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.apply(
+ batching.map_and_batch(
+ math_ops.matmul,
+ num_parallel_calls=optimization.AUTOTUNE,
+ batch_size=batch_size))
+ options = dataset_ops.Options()
+ options.experimental_numa_aware = numa_aware
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with session.Session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(100):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ self.report_benchmark(
+ iters=100,
+ wall_time=np.median(deltas),
+ name=("numa_" if numa_aware else "") + "map_and_batch_autotune")
+
+ def benchmarkInterleave(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(math_ops.matmul)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset,
+ cycle_length=10,
+ num_parallel_calls=optimization.AUTOTUNE)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with session.Session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(1000):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=np.median(deltas),
+ name="interleave_autotune")
+
+ def benchmarkMapAndInterleave(self):
+ k = 1024 * 1024
+ a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1))
+ b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))
+ c = (np.random.rand(1, 2 * k), np.random.rand(2 * k, 1))
+ dataset = dataset_ops.Dataset.from_tensors((a, b, c)).repeat()
+
+ def f1(a, b, c):
+ x, y = a
+ return math_ops.matmul(x, y), b, c
+
+ def f2(a, b, c):
+ x, y = b
+ return a, math_ops.matmul(x, y), c
+
+ def f3(a, b, c):
+ x, y = c
+ return a, b, math_ops.matmul(x, y)
+
+ dataset = dataset.map(f1, num_parallel_calls=optimization.AUTOTUNE)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset,
+ num_parallel_calls=optimization.AUTOTUNE,
+ cycle_length=2)
+
+ dataset = dataset.map(f2, num_parallel_calls=optimization.AUTOTUNE)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset,
+ num_parallel_calls=optimization.AUTOTUNE,
+ cycle_length=2)
+
+ dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with session.Session() as sess:
+ for _ in range(5):
+ sess.run(get_next)
+ for _ in range(100):
+ start = time.time()
+ sess.run(get_next)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+ self.report_benchmark(
+ iters=100,
+ wall_time=np.median(deltas),
+ name="map_and_interleave_autotune")
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/benchmarks/csv_dataset_benchmark.py b/tensorflow/python/data/experimental/benchmarks/csv_dataset_benchmark.py
new file mode 100644
index 0000000..7eebf49
--- /dev/null
+++ b/tensorflow/python/data/experimental/benchmarks/csv_dataset_benchmark.py
@@ -0,0 +1,129 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmarks for `tf.data.experimental.CsvDataset`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import string
+import tempfile
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+
+class CsvDatasetBenchmark(test.Benchmark):
+ """Benchmarks for `tf.data.experimental.CsvDataset`."""
+
+ FLOAT_VAL = '1.23456E12'
+ STR_VAL = string.ascii_letters * 10
+
+ def _setUp(self, str_val):
+ # Since this isn't test.TestCase, have to manually create a test dir
+ gfile.MakeDirs(googletest.GetTempDir())
+ self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
+
+ self._num_cols = [4, 64, 256]
+ self._num_per_iter = 5000
+ self._filenames = []
+ for n in self._num_cols:
+ fn = os.path.join(self._temp_dir, 'file%d.csv' % n)
+ with open(fn, 'wb') as f:
+ # Just write 100 rows and use `repeat`... Assumes the cost
+ # of creating an iterator is not significant
+ row = ','.join([str_val for _ in range(n)])
+ f.write('\n'.join([row for _ in range(100)]))
+ self._filenames.append(fn)
+
+ def _tearDown(self):
+ gfile.DeleteRecursively(self._temp_dir)
+
+ def _runBenchmark(self, dataset, num_cols, prefix):
+ dataset = dataset.skip(self._num_per_iter - 1)
+ deltas = []
+ for _ in range(10):
+ next_element = dataset.make_one_shot_iterator().get_next()
+ with session.Session() as sess:
+ start = time.time()
+ # NOTE: This depends on the underlying implementation of skip, to have
+ # the net effect of calling `GetNext` num_per_iter times on the
+ # input dataset. We do it this way (instead of a python for loop, or
+ # batching N inputs in one iter) so that the overhead from session.run
+ # or batch doesn't dominate. If we eventually optimize skip, this has
+ # to change.
+ sess.run(next_element)
+ end = time.time()
+ deltas.append(end - start)
+ # Median wall time per CSV record read and decoded
+ median_wall_time = np.median(deltas) / self._num_per_iter
+ print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols,
+ median_wall_time))
+ self.report_benchmark(
+ iters=self._num_per_iter,
+ wall_time=median_wall_time,
+ name='%s_with_cols_%d' % (prefix, num_cols))
+
+ def benchmarkMapWithFloats(self):
+ self._setUp(self.FLOAT_VAL)
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [[0.0]] * num_cols}
+ dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
+ dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
+ self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv')
+ self._tearDown()
+
+ def benchmarkMapWithStrings(self):
+ self._setUp(self.STR_VAL)
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [['']] * num_cols}
+ dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
+ dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
+ self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv')
+ self._tearDown()
+
+ def benchmarkCsvDatasetWithFloats(self):
+ self._setUp(self.FLOAT_VAL)
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [[0.0]] * num_cols}
+ dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
+ dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop
+ self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset')
+ self._tearDown()
+
+ def benchmarkCsvDatasetWithStrings(self):
+ self._setUp(self.STR_VAL)
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [['']] * num_cols}
+ dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
+ dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop
+ self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset')
+ self._tearDown()
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/experimental/benchmarks/map_and_batch_benchmark.py b/tensorflow/python/data/experimental/benchmarks/map_and_batch_benchmark.py
index a90156c..1e8dd0f 100644
--- a/tensorflow/python/data/experimental/benchmarks/map_and_batch_benchmark.py
+++ b/tensorflow/python/data/experimental/benchmarks/map_and_batch_benchmark.py
@@ -17,6 +17,8 @@
from __future__ import division
from __future__ import print_function
+import hashlib
+import itertools
import time
import numpy as np
@@ -25,11 +27,15 @@
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
+_NUMPY_RANDOM_SEED = 42
+
class MapAndBatchBenchmark(test.Benchmark):
"""Benchmarks for `tf.data.experimental.map_and_batch()`."""
@@ -89,6 +95,129 @@
name="benchmark_batch_dense_dataset_nnz_%d_batch_size_%d" % (
np.prod(shape), batch_size))
+ def benchmarkMapAndBatchChainingVersusFusing(self):
+ """Compares the performance of chaining and fusing map and batch.
+
+ NOTE: It is recommended to build the benchmark with
+ `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt`
+ and execute it on a machine with at least 32 CPU cores.
+ """
+
+ # Sequential pipeline configurations.
+ seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16])
+ seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64])
+
+ # Parallel pipeline configuration.
+ par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256])
+ par_batch_size_series = itertools.product([32], [32], [1],
+ [128, 256, 512, 1024])
+ par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512])
+ par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512])
+
+ def name(method, label, num_calls, inter_op, element_size, batch_size):
+ return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
+ method,
+ hashlib.sha1(label).hexdigest()[:8],
+ num_calls,
+ inter_op,
+ element_size,
+ batch_size,
+ ))
+
+ def benchmark(label, series):
+ """Runs benchmark the given series."""
+
+ print("%s:" % label)
+
+ def make_base_dataset(element_size):
+ k = 1024 * 1024
+ x = constant_op.constant(np.random.rand(element_size, 4 * k))
+ y = constant_op.constant(np.random.rand(4 * k, 1))
+ return dataset_ops.Dataset.range(1000000000000).map(lambda _: (x, y))
+
+ for num_calls, inter_op, element_size, batch_size in series:
+
+ num_iters = 1024 // (
+ (element_size * batch_size) // min(num_calls, inter_op))
+ dataset = make_base_dataset(element_size)
+ chained_dataset = dataset.map(
+ math_ops.matmul,
+ num_parallel_calls=num_calls).batch(batch_size=batch_size)
+ chained_iterator = chained_dataset.make_one_shot_iterator()
+ chained_get_next = chained_iterator.get_next()
+
+ chained_deltas = []
+ with session.Session(
+ config=config_pb2.ConfigProto(
+ inter_op_parallelism_threads=inter_op,
+ use_per_session_threads=True)) as sess:
+ for _ in range(5):
+ sess.run(chained_get_next.op)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(chained_get_next.op)
+ end = time.time()
+ chained_deltas.append(end - start)
+
+ fused_dataset = dataset.apply(
+ batching.map_and_batch(
+ math_ops.matmul,
+ num_parallel_calls=num_calls,
+ batch_size=batch_size))
+ fused_iterator = fused_dataset.make_one_shot_iterator()
+ fused_get_next = fused_iterator.get_next()
+
+ fused_deltas = []
+ with session.Session(
+ config=config_pb2.ConfigProto(
+ inter_op_parallelism_threads=inter_op,
+ use_per_session_threads=True)) as sess:
+
+ for _ in range(5):
+ sess.run(fused_get_next.op)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(fused_get_next.op)
+ end = time.time()
+ fused_deltas.append(end - start)
+
+ print(
+ "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
+ "element size: %d, num iters: %d\nchained wall time: %f (median), "
+ "%f (mean), %f (stddev), %f (min), %f (max)\n fused wall time: "
+ "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n "
+ "chained/fused: %.2fx (median), %.2fx (mean)" %
+ (batch_size, num_calls, inter_op, element_size, num_iters,
+ np.median(chained_deltas), np.mean(chained_deltas),
+ np.std(chained_deltas), np.min(chained_deltas),
+ np.max(chained_deltas), np.median(fused_deltas),
+ np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas),
+ np.max(fused_deltas),
+ np.median(chained_deltas) / np.median(fused_deltas),
+ np.mean(chained_deltas) / np.mean(fused_deltas)))
+
+ self.report_benchmark(
+ iters=num_iters,
+ wall_time=np.median(chained_deltas),
+ name=name("chained", label, num_calls, inter_op, element_size,
+ batch_size))
+
+ self.report_benchmark(
+ iters=num_iters,
+ wall_time=np.median(fused_deltas),
+ name=name("fused", label, num_calls, inter_op, element_size,
+ batch_size))
+
+ print()
+
+ np.random.seed(_NUMPY_RANDOM_SEED)
+ benchmark("Sequential element size evaluation", seq_elem_size_series)
+ benchmark("Sequential batch size evaluation", seq_batch_size_series)
+ benchmark("Parallel element size evaluation", par_elem_size_series)
+ benchmark("Parallel batch size evaluation", par_batch_size_series)
+ benchmark("Transformation parallelism evaluation", par_num_calls_series)
+ benchmark("Threadpool size evaluation", par_inter_op_series)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/benchmarks/map_benchmark.py b/tensorflow/python/data/experimental/benchmarks/map_benchmark.py
deleted file mode 100644
index ad253cf..0000000
--- a/tensorflow/python/data/experimental/benchmarks/map_benchmark.py
+++ /dev/null
@@ -1,245 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import hashlib
-import itertools
-import time
-
-import numpy as np
-
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.client import session
-from tensorflow.python.data.experimental.ops import batching
-from tensorflow.python.data.experimental.ops import optimization
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-_NUMPY_RANDOM_SEED = 42
-
-
-class MapDatasetBenchmark(test.Benchmark):
-
- # The purpose of this benchmark is to compare the performance of chaining vs
- # fusing of the map and batch transformations across various configurations.
- #
- # NOTE: It is recommended to build the benchmark with
- # `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt`
- # and execute it on a machine with at least 32 CPU cores.
- def benchmarkMapAndBatch(self):
-
- # Sequential pipeline configurations.
- seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16])
- seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64])
-
- # Parallel pipeline configuration.
- par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256])
- par_batch_size_series = itertools.product([32], [32], [1],
- [128, 256, 512, 1024])
- par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512])
- par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512])
-
- def name(method, label, num_calls, inter_op, element_size, batch_size):
- return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
- method,
- hashlib.sha1(label).hexdigest(),
- num_calls,
- inter_op,
- element_size,
- batch_size,
- ))
-
- def benchmark(label, series):
-
- print("%s:" % label)
- for num_calls, inter_op, element_size, batch_size in series:
-
- num_iters = 1024 // (
- (element_size * batch_size) // min(num_calls, inter_op))
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(
- element_size, 4 * k), np.random.rand(4 * k, 1))).repeat()
-
- chained_dataset = dataset.map(
- math_ops.matmul,
- num_parallel_calls=num_calls).batch(batch_size=batch_size)
- chained_iterator = chained_dataset.make_one_shot_iterator()
- chained_get_next = chained_iterator.get_next()
-
- chained_deltas = []
- with session.Session(
- config=config_pb2.ConfigProto(
- inter_op_parallelism_threads=inter_op,
- use_per_session_threads=True)) as sess:
- for _ in range(5):
- sess.run(chained_get_next.op)
- for _ in range(num_iters):
- start = time.time()
- sess.run(chained_get_next.op)
- end = time.time()
- chained_deltas.append(end - start)
-
- fused_dataset = dataset.apply(
- batching.map_and_batch(
- math_ops.matmul,
- num_parallel_calls=num_calls,
- batch_size=batch_size))
- fused_iterator = fused_dataset.make_one_shot_iterator()
- fused_get_next = fused_iterator.get_next()
-
- fused_deltas = []
- with session.Session(
- config=config_pb2.ConfigProto(
- inter_op_parallelism_threads=inter_op,
- use_per_session_threads=True)) as sess:
-
- for _ in range(5):
- sess.run(fused_get_next.op)
- for _ in range(num_iters):
- start = time.time()
- sess.run(fused_get_next.op)
- end = time.time()
- fused_deltas.append(end - start)
-
- print(
- "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
- "element size: %d, num iters: %d\nchained wall time: %f (median), "
- "%f (mean), %f (stddev), %f (min), %f (max)\n fused wall time: "
- "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n "
- "chained/fused: %.2fx (median), %.2fx (mean)" %
- (batch_size, num_calls, inter_op, element_size, num_iters,
- np.median(chained_deltas), np.mean(chained_deltas),
- np.std(chained_deltas), np.min(chained_deltas),
- np.max(chained_deltas), np.median(fused_deltas),
- np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas),
- np.max(fused_deltas),
- np.median(chained_deltas) / np.median(fused_deltas),
- np.mean(chained_deltas) / np.mean(fused_deltas)))
-
- self.report_benchmark(
- iters=num_iters,
- wall_time=np.median(chained_deltas),
- name=name("chained", label, num_calls, inter_op, element_size,
- batch_size))
-
- self.report_benchmark(
- iters=num_iters,
- wall_time=np.median(fused_deltas),
- name=name("fused", label, num_calls, inter_op, element_size,
- batch_size))
-
- print("")
-
- np.random.seed(_NUMPY_RANDOM_SEED)
- benchmark("Sequential element size evaluation", seq_elem_size_series)
- benchmark("Sequential batch size evaluation", seq_batch_size_series)
- benchmark("Parallel element size evaluation", par_elem_size_series)
- benchmark("Parallel batch size evaluation", par_batch_size_series)
- benchmark("Transformation parallelism evaluation", par_num_calls_series)
- benchmark("Threadpool size evaluation", par_inter_op_series)
-
- # This benchmark compares the performance of pipeline with multiple chained
- # maps with and without map fusion.
- def benchmarkChainOfMaps(self):
- chain_lengths = [0, 1, 2, 5, 10, 20, 50]
- for chain_length in chain_lengths:
- self._benchmarkChainOfMaps(chain_length, False)
- self._benchmarkChainOfMaps(chain_length, True)
-
- def _benchmarkChainOfMaps(self, chain_length, optimize_dataset):
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
- for _ in range(chain_length):
- dataset = dataset.map(lambda x: x)
- if optimize_dataset:
- dataset = dataset.apply(optimization.optimize(["map_fusion"]))
-
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(5):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- opt_mark = "opt" if optimize_dataset else "no-opt"
- print("Map dataset {} chain length: {} Median wall time: {}".format(
- opt_mark, chain_length, median_wall_time))
- self.report_benchmark(
- iters=1000,
- wall_time=median_wall_time,
- name="benchmark_map_dataset_chain_latency_{}_{}".format(
- opt_mark, chain_length))
-
-
-class MapAndFilterBenchmark(test.Benchmark):
-
- # This benchmark compares the performance of pipeline with multiple chained
- # map + filter with and without map fusion.
- def benchmarkMapAndFilter(self):
- chain_lengths = [0, 1, 2, 5, 10, 20, 50]
- for chain_length in chain_lengths:
- self._benchmarkMapAndFilter(chain_length, False)
- self._benchmarkMapAndFilter(chain_length, True)
-
- def _benchmarkMapAndFilter(self, chain_length, optimize_dataset):
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
- for _ in range(chain_length):
- dataset = dataset.map(lambda x: x + 5).filter(
- lambda x: math_ops.greater_equal(x - 5, 0))
- if optimize_dataset:
- dataset = dataset.apply(
- optimization.optimize(["map_and_filter_fusion"]))
-
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(10):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- opt_mark = "opt" if optimize_dataset else "no-opt"
- print("Map and filter dataset {} chain length: {} Median wall time: {}".
- format(opt_mark, chain_length, median_wall_time))
- self.report_benchmark(
- iters=1000,
- wall_time=median_wall_time,
- name="benchmark_map_and_filter_dataset_chain_latency_{}_{}".format(
- opt_mark, chain_length))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/experimental/benchmarks/map_vectorization_benchmark.py b/tensorflow/python/data/experimental/benchmarks/map_vectorization_benchmark.py
new file mode 100644
index 0000000..0c3ac8b
--- /dev/null
+++ b/tensorflow/python/data/experimental/benchmarks/map_vectorization_benchmark.py
@@ -0,0 +1,194 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmarks for the `MapVectorization` optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import test
+
+
+def _generate_csv_test_case():
+ """Generates a `decode_csv()` test case."""
+
+ def csv_factory():
+ return dataset_ops.Dataset.from_tensor_slices(["1.0:2:a",
+ "2.4:5:c"]).repeat(5)
+
+ def decode_csv_fn(x):
+ return parsing_ops.decode_csv(
+ x,
+ record_defaults=[
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.string)
+ ],
+ field_delim=":")
+
+ return decode_csv_fn, csv_factory
+
+
+def _generate_parse_single_example_test_case():
+ """Generates a `parse_single_example()` test case."""
+
+ def parse_example_factory():
+ """Parse example factory."""
+
+ def _int64_feature(*values):
+ return feature_pb2.Feature(int64_list=feature_pb2.Int64List(value=values))
+
+ def _bytes_feature(*values):
+ return feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[v.encode("utf-8") for v in values]))
+
+ return dataset_ops.Dataset.from_tensor_slices(
+ constant_op.constant([
+ example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ "dense_int": _int64_feature(i),
+ "dense_str": _bytes_feature(str(i)),
+ "sparse_int": _int64_feature(i, i * 2, i * 4, i * 8),
+ "sparse_str": _bytes_feature(*["abc"] * i)
+ })).SerializeToString() for i in range(10)
+ ]))
+
+ def parse_single_example_fn(x):
+ features = {
+ "dense_int": parsing_ops.FixedLenFeature((), dtypes.int64, 0),
+ "dense_str": parsing_ops.FixedLenFeature((), dtypes.string, ""),
+ "sparse_int": parsing_ops.VarLenFeature(dtypes.int64),
+ "sparse_str": parsing_ops.VarLenFeature(dtypes.string),
+ }
+ return parsing_ops.parse_single_example(x, features)
+
+ return parse_single_example_fn, parse_example_factory
+
+
+# TODO(rachelim): Add a benchmark for more expensive transformations, such as
+# vgg_preprocessing.
+class MapVectorizationBenchmark(test.Benchmark):
+ """Benchmarks for the `MapVectorization` optimization."""
+
+ def _run(self, x, num_iters=100, name=None):
+ deltas = []
+ with session.Session() as sess:
+ for _ in range(5):
+ # Warm up session...
+ sess.run(x)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(x)
+ end = time.time()
+ deltas.append(end - start)
+ median_time = np.median(deltas)
+ self.report_benchmark(iters=num_iters, wall_time=median_time, name=name)
+ return median_time
+
+ def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
+ num_elems = int(np.sum([np.prod(x) for x in input_size]))
+ name_template = "{}__batch_size_{}_input_element_size_{}_{}"
+ unoptimized = input_dataset.map(map_fn).batch(batch_size)
+ unoptimized_op = unoptimized.make_one_shot_iterator().get_next()
+
+ optimized = input_dataset.map(map_fn).batch(batch_size)
+ options = dataset_ops.Options()
+ options.experimental_map_vectorization = True
+ optimized = optimized.with_options(options)
+ optimized_op = optimized.make_one_shot_iterator().get_next()
+
+ unoptimized_time = self._run(
+ unoptimized_op,
+ name=name_template.format(str_id, batch_size, num_elems, "unoptimized"))
+ optimized_time = self._run(
+ optimized_op,
+ name=name_template.format(str_id, batch_size, num_elems, "optimized"))
+
+ print("Batch size: {}\n"
+ "Input element size: {}\n"
+ "Transformation: {}\n"
+ "Speedup: {}\n".format(batch_size, input_size, str_id,
+ (unoptimized_time / optimized_time)))
+
+ # Known cheap functions
+ def benchmarkIdentity(self):
+ self._benchmark_helper(lambda *args: [array_ops.identity(x) for x in args],
+ "identity")
+
+ def benchmarkAddConst(self):
+ self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
+
+ def benchmarkReturnConst(self):
+ self._benchmark_helper(lambda *args: [constant_op.constant(2)], "ret_const")
+
+ def benchmarkSelect(self):
+ self._benchmark_helper(lambda *args: args[0], "select")
+
+ def benchmarkCast(self):
+ self._benchmark_helper(
+ lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], "cast")
+
+ def benchmarkReshape(self):
+ self._benchmark_helper(
+ lambda *args: [array_ops.reshape(x, (-1, 30)) for x in args], "reshape")
+
+ def benchmarkDecodeCSV(self):
+ csv_fn, csv_factory = _generate_csv_test_case()
+ self._benchmark_helper(csv_fn, "decode_csv", lambda: [csv_factory()])
+
+ def benchmarkParseSingleExample(self):
+ # NOTE: Since we haven't implemented a vectorizer for "SerializeSparse",
+ # this function is only naively vectorized.
+ parse_fn, parse_factory = _generate_parse_single_example_test_case()
+
+ self._benchmark_helper(parse_fn, "parse_single_example",
+ lambda: [parse_factory()])
+
+ def _default_dataset_factory(self):
+ input_sizes = [(10, 10, 3), (10, 100, 300)]
+ for sz in input_sizes:
+ yield dataset_ops.Dataset.from_tensor_slices(np.random.rand(*sz))
+
+ def _benchmark_helper(self, map_fn, str_id, base_dataset_factory=None):
+ if base_dataset_factory is None:
+ base_dataset_factory = self._default_dataset_factory
+
+ batch_size = 1000
+ for base_dataset in base_dataset_factory():
+ base_dataset = base_dataset.repeat()
+ input_size = [
+ tuple(shape.as_list())
+ for shape in nest.flatten(base_dataset.output_shapes)
+ ]
+ self._compare(base_dataset, map_fn, batch_size, input_size, str_id)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/benchmarks/matching_files_benchmark.py b/tensorflow/python/data/experimental/benchmarks/matching_files_benchmark.py
new file mode 100644
index 0000000..2eb5561
--- /dev/null
+++ b/tensorflow/python/data/experimental/benchmarks/matching_files_benchmark.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for the experimental `MatchingFilesDataset`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+import tempfile
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import matching_files
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+
+
+class MatchingFilesBenchmark(test.Benchmark):
+ """Benchmark for the experimental `MatchingFilesDataset`."""
+
+ def benchmarkNestedDirectories(self):
+ tmp_dir = tempfile.mkdtemp()
+ width = 500
+ depth = 10
+ for i in range(width):
+ for j in range(depth):
+ new_base = os.path.join(tmp_dir, str(i),
+ *[str(dir_name) for dir_name in range(j)])
+ os.makedirs(new_base)
+ child_files = ['a.py', 'b.pyc'] if j < depth - 1 else ['c.txt', 'd.log']
+ for f in child_files:
+ filename = os.path.join(new_base, f)
+ open(filename, 'w').close()
+
+ patterns = [
+ os.path.join(tmp_dir, os.path.join(*['**'
+ for _ in range(depth)]), suffix)
+ for suffix in ['*.txt', '*.log']
+ ]
+
+ deltas = []
+ iters = 3
+ for _ in range(iters):
+ with ops.Graph().as_default():
+ dataset = matching_files.MatchingFilesDataset(patterns)
+ next_element = dataset.make_one_shot_iterator().get_next()
+
+ with session.Session() as sess:
+ sub_deltas = []
+ while True:
+ try:
+ start = time.time()
+ sess.run(next_element)
+ end = time.time()
+ sub_deltas.append(end - start)
+ except errors.OutOfRangeError:
+ break
+ deltas.append(sub_deltas)
+
+ median_deltas = np.median(deltas, axis=0)
+ print('Nested directory size (width*depth): %d*%d Median wall time: '
+ '%fs (read first filename), %fs (read second filename), avg %fs'
+ ' (read %d more filenames)' %
+ (width, depth, median_deltas[0], median_deltas[1],
+ np.average(median_deltas[2:]), len(median_deltas) - 2))
+ self.report_benchmark(
+ iters=iters,
+ wall_time=np.sum(median_deltas),
+ extras={
+ 'read first file:':
+ median_deltas[0],
+ 'read second file:':
+ median_deltas[1],
+ 'avg time for reading %d more filenames:' %
+ (len(median_deltas) - 2):
+ np.average(median_deltas[2:])
+ },
+ name='dataset_nested_directory(%d*%d)' %
+ (width, depth))
+
+ shutil.rmtree(tmp_dir, ignore_errors=True)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/experimental/benchmarks/optimize_benchmark.py b/tensorflow/python/data/experimental/benchmarks/optimize_benchmark.py
new file mode 100644
index 0000000..0eca97d
--- /dev/null
+++ b/tensorflow/python/data/experimental/benchmarks/optimize_benchmark.py
@@ -0,0 +1,120 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmarks for static optimizations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class OptimizationBenchmark(test.Benchmark):
+ """Benchmarks for static optimizations."""
+
+ def benchmarkMapFusion(self):
+ """Evaluates performance map of fusion."""
+
+ chain_lengths = [0, 1, 2, 5, 10, 20, 50]
+ for chain_length in chain_lengths:
+ self._benchmarkMapFusion(chain_length, False)
+ self._benchmarkMapFusion(chain_length, True)
+
+ def _benchmarkMapFusion(self, chain_length, optimize_dataset):
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset.map(lambda x: x)
+ if optimize_dataset:
+ options = dataset_ops.Options()
+ options.experimental_map_fusion = True
+ dataset = dataset.with_options(options)
+
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(5):
+ sess.run(next_element.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ opt_mark = "opt" if optimize_dataset else "noopt"
+ print("Map dataset {} chain length: {} Median wall time: {}".format(
+ opt_mark, chain_length, median_wall_time))
+ self.report_benchmark(
+ iters=100,
+ wall_time=median_wall_time,
+ name="map_fusion_{}_chain_length_{}".format(
+ opt_mark, chain_length))
+
+ def benchmarkMapAndFilterFusion(self):
+ """Evaluates performance map of fusion."""
+
+ chain_lengths = [0, 1, 2, 5, 10, 20, 50]
+ for chain_length in chain_lengths:
+ self._benchmarkMapAndFilterFusion(chain_length, False)
+ self._benchmarkMapAndFilterFusion(chain_length, True)
+
+ def _benchmarkMapAndFilterFusion(self, chain_length, optimize_dataset):
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset.map(lambda x: x + 5).filter(
+ lambda x: math_ops.greater_equal(x - 5, 0))
+ if optimize_dataset:
+ options = dataset_ops.Options()
+ options.experimental_map_and_filter_fusion = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(10):
+ sess.run(next_element.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ opt_mark = "opt" if optimize_dataset else "noopt"
+ print("Map and filter dataset {} chain length: {} Median wall time: {}"
+ .format(opt_mark, chain_length, median_wall_time))
+ self.report_benchmark(
+ iters=100,
+ wall_time=median_wall_time,
+ name="map_and_filter_fusion_{}_chain_length_{}".format(
+ opt_mark, chain_length))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/benchmarks/unbatch_benchmark.py b/tensorflow/python/data/experimental/benchmarks/unbatch_benchmark.py
new file mode 100644
index 0000000..c40d479
--- /dev/null
+++ b/tensorflow/python/data/experimental/benchmarks/unbatch_benchmark.py
@@ -0,0 +1,107 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.unbatch()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class UnbatchBenchmark(test.Benchmark):
+ """Benchmarks for `tf.data.experimental.unbatch()`."""
+
+ def benchmarkNativeUnbatch(self):
+ batch_sizes = [1, 2, 5, 10, 20, 50]
+ elems_per_trial = 10000
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+ batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset.batch(batch_size_placeholder)
+ dataset = dataset.apply(batching.unbatch())
+ dataset = dataset.skip(elems_per_trial)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for batch_size in batch_sizes:
+ deltas = []
+ for _ in range(5):
+ sess.run(
+ iterator.initializer,
+ feed_dict={batch_size_placeholder: batch_size})
+ start = time.time()
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append((end - start) / elems_per_trial)
+
+ median_wall_time = np.median(deltas)
+ print("Unbatch (native) batch size: %d Median wall time per element:"
+ " %f microseconds" % (batch_size, median_wall_time * 1e6))
+ self.report_benchmark(
+ iters=10000,
+ wall_time=median_wall_time,
+ name="native_batch_size_%d" %
+ batch_size)
+
+ # Include a benchmark of the previous `unbatch()` implementation that uses
+ # a composition of more primitive ops. Eventually we'd hope to generate code
+ # that is as good in both cases.
+ def benchmarkOldUnbatchImplementation(self):
+ batch_sizes = [1, 2, 5, 10, 20, 50]
+ elems_per_trial = 10000
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+ batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset.batch(batch_size_placeholder)
+ dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
+ dataset = dataset.skip(elems_per_trial)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for batch_size in batch_sizes:
+ deltas = []
+ for _ in range(5):
+ sess.run(
+ iterator.initializer,
+ feed_dict={batch_size_placeholder: batch_size})
+ start = time.time()
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append((end - start) / elems_per_trial)
+
+ median_wall_time = np.median(deltas)
+ print("Unbatch (unfused) batch size: %d Median wall time per element:"
+ " %f microseconds" % (batch_size, median_wall_time * 1e6))
+ self.report_benchmark(
+ iters=10000,
+ wall_time=median_wall_time,
+ name="unfused_batch_size_%d" %
+ batch_size)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
index c9b11a2..6b22f9b 100644
--- a/tensorflow/python/data/experimental/kernel_tests/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -72,15 +72,11 @@
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:platform_test",
- "//tensorflow/python:session",
"//tensorflow/python/data/experimental/ops:error_ops",
"//tensorflow/python/data/experimental/ops:readers",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/eager:context",
- "//third_party/py/numpy",
],
)
@@ -372,6 +368,25 @@
)
py_test(
+ name = "matching_files_test",
+ size = "small",
+ srcs = ["matching_files_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:matching_files",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "override_threadpool_test",
size = "small",
srcs = ["override_threadpool_test.py"],
@@ -618,7 +633,9 @@
size = "medium",
srcs = ["stats_dataset_ops_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_pip",
+ ],
deps = [
":reader_dataset_ops_test_base",
":stats_dataset_test_base",
diff --git a/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
deleted file mode 100644
index dbb780c..0000000
--- a/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
+++ /dev/null
@@ -1,690 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import math
-import time
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.python.client import session
-from tensorflow.python.data.experimental.ops import batching
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import script_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- def testDenseToSparseBatchDataset(self):
- components = np.random.randint(12, size=(100,)).astype(np.int32)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([x], x)).apply(
- batching.dense_to_sparse_batch(4,
- [12])).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
-
- for start in range(0, len(components), 4):
- results = self.evaluate(get_next)
- self.assertAllEqual([[i, j]
- for i, c in enumerate(components[start:start + 4])
- for j in range(c)], results.indices)
- self.assertAllEqual(
- [c for c in components[start:start + 4] for _ in range(c)],
- results.values)
- self.assertAllEqual([min(4,
- len(components) - start), 12],
- results.dense_shape)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testDenseToSparseBatchDatasetWithUnknownShape(self):
- components = np.random.randint(5, size=(40,)).astype(np.int32)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).map(
- lambda x: array_ops.fill([x, x], x)).apply(
- batching.dense_to_sparse_batch(
- 4, [5, None])).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
-
- for start in range(0, len(components), 4):
- results = self.evaluate(get_next)
- self.assertAllEqual([[i, j, z]
- for i, c in enumerate(components[start:start + 4])
- for j in range(c)
- for z in range(c)], results.indices)
- self.assertAllEqual([
- c for c in components[start:start + 4] for _ in range(c)
- for _ in range(c)
- ], results.values)
- self.assertAllEqual([
- min(4,
- len(components) - start), 5,
- np.max(components[start:start + 4])
- ], results.dense_shape)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testDenseToSparseBatchDatasetWithInvalidShape(self):
- input_tensor = array_ops.constant([[1]])
- with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
- dataset_ops.Dataset.from_tensors(input_tensor).apply(
- batching.dense_to_sparse_batch(4,
- [-2])).make_initializable_iterator()
-
- def testDenseToSparseBatchDatasetShapeErrors(self):
- input_tensor = array_ops.placeholder(dtypes.int32)
- iterator = (
- dataset_ops.Dataset.from_tensors(input_tensor).apply(
- batching.dense_to_sparse_batch(4,
- [12])).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # Initialize with an input tensor of incompatible rank.
- sess.run(init_op, feed_dict={input_tensor: [[1]]})
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "incompatible with the row shape"):
- sess.run(get_next)
-
- # Initialize with an input tensor that is larger than `row_shape`.
- sess.run(init_op, feed_dict={input_tensor: range(13)})
- with self.assertRaisesRegexp(errors.DataLossError,
- "larger than the row shape"):
- sess.run(get_next)
-
- def testUnbatchWithUnknownRankInput(self):
- placeholder = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensors(placeholder).apply(
- batching.unbatch())
- iterator = dataset.make_initializable_iterator()
- next_elem = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
- for i in range(4):
- self.assertEqual(i, self.evaluate(next_elem))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_elem)
-
- def testUnbatchScalarDataset(self):
- data = tuple([math_ops.range(10) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = (dtypes.int32,) * 3
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual((i,) * 3, self.evaluate(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchDatasetWithStrings(self):
- data = tuple([math_ops.range(10) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
- expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual((i, compat.as_bytes(str(i)), i), self.evaluate(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchDatasetWithSparseTensor(self):
- st = sparse_tensor.SparseTensorValue(
- indices=[[i, i] for i in range(10)],
- values=list(range(10)),
- dense_shape=[10, 10])
- data = dataset_ops.Dataset.from_tensors(st)
- data = data.apply(batching.unbatch())
- data = data.batch(5)
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- st_row = self.evaluate(next_element)
- self.assertEqual([i], st_row.indices)
- self.assertEqual([i], st_row.values)
- self.assertEqual([10], st_row.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchDatasetWithDenseAndSparseTensor(self):
- st = sparse_tensor.SparseTensorValue(
- indices=[[i, i] for i in range(10)],
- values=list(range(10)),
- dense_shape=[10, 10])
- data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
- data = data.apply(batching.unbatch())
- data = data.batch(5)
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- dense_elem, st_row = self.evaluate(next_element)
- self.assertEqual(i, dense_elem)
- self.assertEqual([i], st_row.indices)
- self.assertEqual([i], st_row.values)
- self.assertEqual([10], st_row.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchSingleElementTupleDataset(self):
- data = tuple([(math_ops.range(10),) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = ((dtypes.int32,),) * 3
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(((i,),) * 3, self.evaluate(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchMultiElementTupleDataset(self):
- data = tuple([(math_ops.range(10 * i, 10 * i + 10),
- array_ops.fill([10], "hi")) for i in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = ((dtypes.int32, dtypes.string),) * 3
- data = data.batch(2)
- self.assertAllEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertAllEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
- sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchEmpty(self):
- data = dataset_ops.Dataset.from_tensors(
- (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
- constant_op.constant([], shape=[0, 4, 0])))
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchStaticShapeMismatch(self):
- data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
- np.arange(9)))
- with self.assertRaises(ValueError):
- data.apply(batching.unbatch())
-
- def testUnbatchDynamicShapeMismatch(self):
- ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
- ph2 = array_ops.placeholder(dtypes.int32, shape=None)
- data = dataset_ops.Dataset.from_tensors((ph1, ph2))
- data = data.apply(batching.unbatch())
- iterator = data.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- # Mismatch in the 0th dimension.
- sess.run(
- iterator.initializer,
- feed_dict={
- ph1: np.arange(7).astype(np.int32),
- ph2: np.arange(8).astype(np.int32)
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(next_element)
-
- # No 0th dimension (i.e. scalar value) for one component.
- sess.run(
- iterator.initializer,
- feed_dict={
- ph1: np.arange(7).astype(np.int32),
- ph2: 7
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(next_element)
-
- @parameterized.named_parameters(
- ("Default", None, None),
- ("SequentialCalls", 1, None),
- ("ParallelCalls", 2, None),
- ("ParallelBatches", None, 10),
- )
- def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
- """Test a dataset that maps a TF function across its input elements."""
- # The pipeline is TensorSliceDataset ->
- # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
- components = (np.arange(7),
- np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
- np.array(37.0) * np.arange(7))
-
- count = array_ops.placeholder(dtypes.int64, shape=[])
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
- batching.map_and_batch(
- map_func=_map_fn,
- batch_size=batch_size,
- num_parallel_calls=num_parallel_calls,
- num_parallel_batches=num_parallel_batches))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([[None] + list(c.shape[1:]) for c in components],
- [t.shape.as_list() for t in get_next])
-
- with self.cached_session() as sess:
- # Batch of a finite input, where the batch_size divides the
- # total number of elements.
- sess.run(init_op, feed_dict={count: 28, batch_size: 14})
- num_batches = (28 * 7) // 14
- for i in range(num_batches):
- result = self.evaluate(get_next)
- for component, result_component in zip(components, result):
- for j in range(14):
- self.assertAllEqual(component[(i * 14 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Batch of a finite input, where the batch_size does not
- # divide the total number of elements.
- sess.run(init_op, feed_dict={count: 14, batch_size: 8})
-
- # We expect (num_batches - 1) full-sized batches.
- num_batches = int(math.ceil((14 * 7) / 8))
- for i in range(num_batches - 1):
- result = self.evaluate(get_next)
- for component, result_component in zip(components, result):
- for j in range(8):
- self.assertAllEqual(component[(i * 8 + j) % 7]**2,
- result_component[j])
- result = self.evaluate(get_next)
- for component, result_component in zip(components, result):
- for j in range((14 * 7) % 8):
- self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Batch of an empty input should fail straight away.
- sess.run(init_op, feed_dict={count: 0, batch_size: 8})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Empty batch should be an initialization time error.
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(init_op, feed_dict={count: 14, batch_size: 0})
-
- @parameterized.named_parameters(
- ("Even", False),
- ("Uneven", True),
- )
- def testMapAndBatchPartialBatch(self, drop_remainder):
- iterator = (
- dataset_ops.Dataset.range(10).apply(
- batching.map_and_batch(
- lambda x: array_ops.reshape(x * x, [1]),
- batch_size=4,
- drop_remainder=drop_remainder)).make_one_shot_iterator())
- if drop_remainder:
- self.assertEqual([4, 1], iterator.output_shapes.as_list())
- else:
- self.assertEqual([None, 1], iterator.output_shapes.as_list())
- next_element = iterator.get_next()
- with self.cached_session() as sess:
- self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
- self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
- if not drop_remainder:
- self.assertAllEqual([[64], [81]], self.evaluate(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testMapAndBatchYieldsPartialBatch(self):
- iterator = (
- dataset_ops.Dataset.range(10).apply(
- batching.map_and_batch(lambda x: array_ops.reshape(x * x, [1]),
- 4)).make_one_shot_iterator())
- self.assertEqual([None, 1], iterator.output_shapes.as_list())
- next_element = iterator.get_next()
- with self.cached_session() as sess:
- self.assertAllEqual([[0], [1], [4], [9]], self.evaluate(next_element))
- self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
- self.assertAllEqual([[64], [81]], self.evaluate(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testMapAndBatchParallelGetNext(self):
- iterator = (
- dataset_ops.Dataset.range(50000).apply(
- batching.map_and_batch(lambda x: x,
- batch_size=100)).make_one_shot_iterator())
- elements = []
- for _ in range(100):
- elements.append(iterator.get_next())
- with self.cached_session() as sess:
- for i in range(5):
- got = self.evaluate(elements)
- got.sort(key=lambda x: x[0])
- expected = []
- for j in range(100):
- expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
- self.assertAllEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elements)
-
- def testMapAndBatchParallelGetNextDropRemainder(self):
- iterator = (
- dataset_ops.Dataset.range(49999).apply(
- batching.map_and_batch(
- lambda x: x, batch_size=100,
- drop_remainder=True)).make_one_shot_iterator())
- elements = []
- for _ in range(100):
- elements.append(iterator.get_next())
- with self.cached_session() as sess:
- for i in range(4):
- got = self.evaluate(elements)
- got.sort(key=lambda x: x[0])
- expected = []
- for j in range(100):
- expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
- self.assertAllEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elements)
-
- def testMapAndBatchSparse(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- iterator = dataset_ops.Dataset.range(10).apply(
- batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for i in range(2):
- actual = self.evaluate(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
- values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
- dense_shape=[5, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testMapAndBatchFails(self):
- """Test a dataset that maps a TF function across its input elements."""
- dataset = dataset_ops.Dataset.from_tensors(
- array_ops.check_numerics(
- constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (
- dataset.apply(batching.map_and_batch(
- lambda x: x, batch_size)).make_initializable_iterator())
- init_op = iterator.initializer
- with self.cached_session() as sess:
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
- sess.run(init_op, feed_dict={batch_size: 14})
-
- def testMapAndBatchShapeMismatch(self):
- """Test a dataset that maps a TF function across its input elements."""
-
- def generator():
- yield [1]
- yield [2]
- yield [3]
- yield [[4, 5, 6]]
-
- dataset = dataset_ops.Dataset.from_generator(
- generator, output_types=dtypes.int32)
- batch_size = 4
- iterator = (
- dataset.apply(batching.map_and_batch(
- lambda x: x, batch_size)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- self.evaluate(init_op)
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "number of elements does not match"):
- sess.run(get_next)
-
- def testMapAndBatchImplicitDispose(self):
- # Tests whether a map and batch dataset will be cleaned up correctly when
- # the pipeline does not run it until exhaustion.
- # The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
- # MapAndBatchDataset(f=square_3, batch_size=100).
- components = (np.arange(1000),
- np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
- np.array(37.0) * np.arange(1000))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
- 1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
- dataset = dataset.prefetch(5)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- for _ in range(3):
- sess.run(get_next)
-
- @parameterized.named_parameters(
- ("1", 0),
- ("2", 5),
- ("3", 10),
- ("4", 90),
- ("5", 95),
- ("6", 99),
- )
- def testMapAndBatchOutOfRangeError(self, threshold):
-
- def raising_py_fn(i):
- if i >= threshold:
- raise StopIteration()
- else:
- return i
-
- iterator = (
- dataset_ops.Dataset.range(100).apply(
- batching.map_and_batch(
- lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
- batch_size=10)).make_one_shot_iterator())
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(threshold // 10):
- self.assertAllEqual([i * 10 + j for j in range(10)],
- self.evaluate(get_next))
- if threshold % 10 != 0:
- self.assertAllEqual(
- [threshold // 10 * 10 + j for j in range(threshold % 10)],
- sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- @parameterized.named_parameters(
- ("1", False, dtypes.bool),
- ("2", -42, dtypes.int8),
- ("3", -42, dtypes.int16),
- ("4", -42, dtypes.int32),
- ("5", -42, dtypes.int64),
- ("6", 42, dtypes.uint8),
- ("7", 42, dtypes.uint16),
- ("8", 42.0, dtypes.float16),
- ("9", 42.0, dtypes.float32),
- ("10", 42.0, dtypes.float64),
- ("11", b"hello", dtypes.string),
- )
- def testMapAndBatchTypes(self, element, dtype):
-
- def gen():
- yield element
-
- dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
- batching.map_and_batch(lambda x: x, batch_size=10))
-
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- for _ in range(10):
- self.assertAllEqual([element for _ in range(10)],
- self.evaluate(get_next))
-
-
-class UnbatchDatasetBenchmark(test.Benchmark):
-
- def benchmarkNativeUnbatch(self):
- batch_sizes = [1, 2, 5, 10, 20, 50]
- elems_per_trial = 10000
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
- batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset.batch(batch_size_placeholder)
- dataset = dataset.apply(batching.unbatch())
- dataset = dataset.skip(elems_per_trial)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for batch_size in batch_sizes:
- deltas = []
- for _ in range(5):
- sess.run(
- iterator.initializer,
- feed_dict={batch_size_placeholder: batch_size})
- start = time.time()
- sess.run(next_element.op)
- end = time.time()
- deltas.append((end - start) / elems_per_trial)
-
- median_wall_time = np.median(deltas)
- print("Unbatch (native) batch size: %d Median wall time per element:"
- " %f microseconds" % (batch_size, median_wall_time * 1e6))
- self.report_benchmark(
- iters=10000,
- wall_time=median_wall_time,
- name="benchmark_unbatch_dataset_native_batch_size_%d" %
- batch_size)
-
- # Include a benchmark of the previous `unbatch()` implementation that uses
- # a composition of more primitive ops. Eventually we'd hope to generate code
- # that is as good in both cases.
- def benchmarkOldUnbatchImplementation(self):
- batch_sizes = [1, 2, 5, 10, 20, 50]
- elems_per_trial = 10000
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
- batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset.batch(batch_size_placeholder)
- dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
- dataset = dataset.skip(elems_per_trial)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for batch_size in batch_sizes:
- deltas = []
- for _ in range(5):
- sess.run(
- iterator.initializer,
- feed_dict={batch_size_placeholder: batch_size})
- start = time.time()
- sess.run(next_element.op)
- end = time.time()
- deltas.append((end - start) / elems_per_trial)
-
- median_wall_time = np.median(deltas)
- print("Unbatch (unfused) batch size: %d Median wall time per element:"
- " %f microseconds" % (batch_size, median_wall_time * 1e6))
- self.report_benchmark(
- iters=10000,
- wall_time=median_wall_time,
- name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
- batch_size)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py
index 4263a90..af20e50 100644
--- a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py
@@ -110,9 +110,9 @@
with self.cached_session() as sess:
batches = []
for _ in range(4):
- batches.append(sess.run(batch))
+ batches.append(self.evaluate(batch))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
+ self.evaluate(batch)
batch_sizes_val = []
lengths_val = []
for batch in batches:
@@ -160,9 +160,9 @@
with self.cached_session() as sess:
batches = []
for _ in range(3):
- batches.append(sess.run(batch))
+ batches.append(self.evaluate(batch))
with self.assertRaisesOpError("bucket_boundaries"):
- sess.run(batch)
+ self.evaluate(batch)
batch_sizes_val = []
lengths_val = []
for batch in batches:
@@ -197,9 +197,9 @@
with self.cached_session() as sess:
batches = []
for _ in range(5):
- batches.append(sess.run(batch))
+ batches.append(self.evaluate(batch))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
+ self.evaluate(batch)
self.assertAllEqual(batches[0], [[1, 0],
[1, 1]])
diff --git a/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py
index 6d063ac..7edaab8 100644
--- a/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py
@@ -59,7 +59,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDeviceInt32(self):
host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
@@ -84,7 +84,7 @@
with self.test_session(config=worker_config) as sess:
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToSameDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
@@ -110,7 +110,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDeviceWithPrefetch(self):
host_dataset = dataset_ops.Dataset.range(10)
@@ -136,7 +136,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyDictToDevice(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
@@ -162,7 +162,7 @@
for i in range(10):
self.assertEqual({"a": i}, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyDictToDeviceWithPrefetch(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
@@ -188,7 +188,7 @@
for i in range(10):
self.assertEqual({"a": i}, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopySparseTensorsToDevice(self):
@@ -222,7 +222,7 @@
self.assertAllEqual([[0, 0]], actual.indices)
self.assertAllEqual([2, 2], actual.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopySparseTensorsToDeviceWithPrefetch(self):
@@ -256,7 +256,7 @@
self.assertAllEqual([[0, 0]], actual.indices)
self.assertAllEqual([2, 2], actual.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDeviceGpu(self):
if not test_util.is_gpu_available():
@@ -275,7 +275,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDeviceGpuWithPrefetch(self):
if not test_util.is_gpu_available():
@@ -294,7 +294,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDeviceGpuWithMap(self):
if not test_util.is_gpu_available():
@@ -330,7 +330,7 @@
self.assertEqual(float(i**2), y)
self.assertEqual(util_compat.as_bytes(str(i)), z)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDeviceGpuInt32(self):
if not test_util.is_gpu_available():
@@ -348,7 +348,7 @@
self.evaluate(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDeviceGpuInt32AndPrefetch(self):
if not test_util.is_gpu_available():
@@ -366,7 +366,7 @@
self.evaluate(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDeviceGpuStrings(self):
if not test_util.is_gpu_available():
@@ -384,7 +384,7 @@
self.evaluate(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDeviceGpuStringsAndPrefetch(self):
if not test_util.is_gpu_available():
@@ -402,7 +402,7 @@
self.evaluate(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDevicePingPongCPUGPU(self):
if not test_util.is_gpu_available():
@@ -424,7 +424,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDeviceWithReInit(self):
host_dataset = dataset_ops.Dataset.range(10)
@@ -454,7 +454,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDeviceWithReInitAndPrefetch(self):
host_dataset = dataset_ops.Dataset.range(10)
@@ -484,7 +484,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDeviceGpuWithReInit(self):
if not test_util.is_gpu_available():
@@ -506,7 +506,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testCopyToDeviceGpuWithReInitAndPrefetch(self):
if not test_util.is_gpu_available():
@@ -528,7 +528,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testIteratorGetNextAsOptionalOnGPU(self):
if not test_util.is_gpu_available():
@@ -547,15 +547,16 @@
# Before initializing the iterator, evaluating the optional fails with
# a FailedPreconditionError.
with self.assertRaises(errors.FailedPreconditionError):
- sess.run(elem_has_value_t)
+ self.evaluate(elem_has_value_t)
with self.assertRaises(errors.FailedPreconditionError):
- sess.run(elem_value_t)
+ self.evaluate(elem_value_t)
# For each element of the dataset, assert that the optional evaluates to
# the expected value.
self.evaluate(iterator.initializer)
for i in range(3):
- elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
+ elem_has_value, elem_value = self.evaluate(
+ [elem_has_value_t, elem_value_t])
self.assertTrue(elem_has_value)
self.assertEqual(i, elem_value)
@@ -564,7 +565,7 @@
for _ in range(2):
self.assertFalse(self.evaluate(elem_has_value_t))
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(elem_value_t)
+ self.evaluate(elem_value_t)
if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py
index fb75be1..b2f1b43 100644
--- a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py
@@ -20,14 +20,8 @@
import gzip
import os
-import string
-import tempfile
-import time
import zlib
-import numpy as np
-
-from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
@@ -38,8 +32,6 @@
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
@@ -537,96 +529,5 @@
record_defaults=record_defaults)
-class CsvDatasetBenchmark(test.Benchmark):
- """Benchmarks for the various ways of creating a dataset from CSV files.
- """
- FLOAT_VAL = '1.23456E12'
- STR_VAL = string.ascii_letters * 10
-
- def _setUp(self, str_val):
- # Since this isn't test.TestCase, have to manually create a test dir
- gfile.MakeDirs(googletest.GetTempDir())
- self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
-
- self._num_cols = [4, 64, 256]
- self._num_per_iter = 5000
- self._filenames = []
- for n in self._num_cols:
- fn = os.path.join(self._temp_dir, 'file%d.csv' % n)
- with open(fn, 'wb') as f:
- # Just write 100 rows and use `repeat`... Assumes the cost
- # of creating an iterator is not significant
- row = ','.join([str_val for _ in range(n)])
- f.write('\n'.join([row for _ in range(100)]))
- self._filenames.append(fn)
-
- def _tearDown(self):
- gfile.DeleteRecursively(self._temp_dir)
-
- def _runBenchmark(self, dataset, num_cols, prefix):
- dataset = dataset.skip(self._num_per_iter - 1)
- deltas = []
- for _ in range(10):
- next_element = dataset.make_one_shot_iterator().get_next()
- with session.Session() as sess:
- start = time.time()
- # NOTE: This depends on the underlying implementation of skip, to have
- # the net effect of calling `GetNext` num_per_iter times on the
- # input dataset. We do it this way (instead of a python for loop, or
- # batching N inputs in one iter) so that the overhead from session.run
- # or batch doesn't dominate. If we eventually optimize skip, this has
- # to change.
- sess.run(next_element)
- end = time.time()
- deltas.append(end - start)
- # Median wall time per CSV record read and decoded
- median_wall_time = np.median(deltas) / self._num_per_iter
- print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols,
- median_wall_time))
- self.report_benchmark(
- iters=self._num_per_iter,
- wall_time=median_wall_time,
- name='%s_with_cols_%d' % (prefix, num_cols))
-
- def benchmarkMapWithFloats(self):
- self._setUp(self.FLOAT_VAL)
- for i in range(len(self._filenames)):
- num_cols = self._num_cols[i]
- kwargs = {'record_defaults': [[0.0]] * num_cols}
- dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
- self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv')
- self._tearDown()
-
- def benchmarkMapWithStrings(self):
- self._setUp(self.STR_VAL)
- for i in range(len(self._filenames)):
- num_cols = self._num_cols[i]
- kwargs = {'record_defaults': [['']] * num_cols}
- dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
- self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv')
- self._tearDown()
-
- def benchmarkCsvDatasetWithFloats(self):
- self._setUp(self.FLOAT_VAL)
- for i in range(len(self._filenames)):
- num_cols = self._num_cols[i]
- kwargs = {'record_defaults': [[0.0]] * num_cols}
- dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop
- self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset')
- self._tearDown()
-
- def benchmarkCsvDatasetWithStrings(self):
- self._setUp(self.STR_VAL)
- for i in range(len(self._filenames)):
- num_cols = self._num_cols[i]
- kwargs = {'record_defaults': [['']] * num_cols}
- dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop
- self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset')
- self._tearDown()
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py
index 9fe2ee4..d9bbfb9 100644
--- a/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py
@@ -56,7 +56,7 @@
results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
def testDenseToSparseBatchDatasetWithUnknownShape(self):
components = np.random.randint(5, size=(40,)).astype(np.int32)
@@ -89,7 +89,7 @@
], results.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
def testDenseToSparseBatchDatasetWithInvalidShape(self):
input_tensor = array_ops.constant([[1]])
@@ -111,13 +111,13 @@
sess.run(init_op, feed_dict={input_tensor: [[1]]})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"incompatible with the row shape"):
- sess.run(get_next)
+ self.evaluate(get_next)
# Initialize with an input tensor that is larger than `row_shape`.
sess.run(init_op, feed_dict={input_tensor: range(13)})
with self.assertRaisesRegexp(errors.DataLossError,
"larger than the row shape"):
- sess.run(get_next)
+ self.evaluate(get_next)
if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py
index 234fd86..768a8d7 100644
--- a/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py
@@ -45,7 +45,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def _normalize(self, vec):
return vec / vec.sum()
@@ -71,9 +71,9 @@
with self.cached_session() as sess:
freqs = np.zeros([num_datasets])
for _ in range(num_samples):
- freqs[sess.run(next_element)] += 1
+ freqs[self.evaluate(next_element)] += 1
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
return freqs
@@ -109,7 +109,7 @@
for i in choice_array:
self.assertEqual(words[i], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testErrors(self):
with self.assertRaisesRegexp(ValueError,
diff --git a/tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py
index 78805bb..f32d1d0 100644
--- a/tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py
@@ -49,7 +49,7 @@
self.assertEqual((21, (b"b", 2, 38.0)), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py
index c6ee88c..4f8cb12 100644
--- a/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py
@@ -52,12 +52,12 @@
with session.Session() as sess:
for _ in range(10):
- sess.run(next_element.op)
+ self.evaluate(next_element.op)
deltas = []
for _ in range(100):
start = time.time()
for _ in range(100):
- sess.run(next_element.op)
+ self.evaluate(next_element.op)
end = time.time()
deltas.append(end - start)
diff --git a/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py b/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py
index 15396f3..f985650 100644
--- a/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py
@@ -42,7 +42,7 @@
got = self.evaluate(get_next)
self.assertEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
def testSum(self):
reducer = grouping.Reducer(
@@ -131,7 +131,7 @@
self.assertAllEqual([0] * (2**i), x)
self.assertAllEqual(np.array(1, ndmin=i), y)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
def testTypeMismatch(self):
reducer = grouping.Reducer(
diff --git a/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py b/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py
index cfc357b..d5a36e7 100644
--- a/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py
@@ -301,7 +301,7 @@
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Window size must be greater than zero, but got 0."):
- print(sess.run(get_next))
+ print(self.evaluate(get_next))
def testReduceFuncError(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
@@ -325,7 +325,7 @@
with self.cached_session() as sess:
self.evaluate(init_op)
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
+ self.evaluate(get_next)
def testConsumeWindowDatasetMoreThanOnce(self):
components = np.random.randint(50, size=(200,)).astype(np.int64)
diff --git a/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py b/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py
index cb0fc13..522b196 100644
--- a/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py
@@ -51,7 +51,7 @@
for x in [1., 2., 3., 5.]:
self.assertEqual(x, self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
def testParallelMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
@@ -69,7 +69,7 @@
for x in [1., 2., 3., 5.]:
self.assertEqual(x, self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
def testReadFileIgnoreError(self):
@@ -97,7 +97,7 @@
for filename in filenames:
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Delete one of the files.
os.remove(filenames[0])
@@ -108,7 +108,7 @@
for filename in filenames[1:]:
self.assertEqual(compat.as_bytes(filename), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py
index c4076da..0a43603 100644
--- a/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py
@@ -46,7 +46,7 @@
handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
with self.cached_session() as sess:
- sess.run(materialize)
+ self.evaluate(materialize)
self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
def testIdentityIndexedDataset(self):
@@ -73,7 +73,8 @@
output = self.evaluate(n)
self.assertEqual(i, output)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(n)
+ self.evaluate(n)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py
index c6cefa7..109b369 100644
--- a/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py
@@ -119,7 +119,7 @@
self.assertAllEqual(file_batch, actual_batch["file"])
self.assertAllEqual(record_batch, actual_batch["record"])
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testReadWithFusedShuffleRepeatDataset(self):
num_epochs = 5
diff --git a/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py
index 5486369..1f50938 100644
--- a/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py
@@ -102,7 +102,7 @@
self.assertAllEqual(expected_features[k], actual_features[k])
with self.assertRaises(errors.OutOfRangeError):
- sess.run(nxt)
+ self.evaluate(nxt)
def _test_dataset(self,
inputs,
@@ -607,8 +607,8 @@
outputs1 = dataset1.make_one_shot_iterator().get_next()
outputs2 = dataset2.make_one_shot_iterator().get_next()
for _ in range(total_records // batch_size):
- batch1 = nest.flatten(sess.run(outputs1))
- batch2 = nest.flatten(sess.run(outputs2))
+ batch1 = nest.flatten(self.evaluate(outputs1))
+ batch2 = nest.flatten(self.evaluate(outputs2))
for i in range(len(batch1)):
self.assertAllEqual(batch1[i], batch2[i])
@@ -639,8 +639,8 @@
outputs2 = dataset2.make_one_shot_iterator().get_next()
all_equal = False
for _ in range(total_records // batch_size):
- batch1 = nest.flatten(sess.run(outputs1))
- batch2 = nest.flatten(sess.run(outputs2))
+ batch1 = nest.flatten(self.evaluate(outputs1))
+ batch2 = nest.flatten(self.evaluate(outputs2))
for i in range(len(batch1)):
all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
self.assertFalse(all_equal)
diff --git a/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py
index 404edf2..0bb7b7c 100644
--- a/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py
@@ -135,7 +135,7 @@
interleave_cycle_length=num_parallel_reads,
drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(outputs)
+ self.evaluate(outputs)
def testRead(self):
for batch_size in [1, 2]:
@@ -192,7 +192,7 @@
first_batches = []
try:
while True:
- first_batches.append(sess.run(next_element))
+ first_batches.append(self.evaluate(next_element))
except errors.OutOfRangeError:
pass
@@ -200,7 +200,7 @@
second_batches = []
try:
while True:
- second_batches.append(sess.run(next_element))
+ second_batches.append(self.evaluate(next_element))
except errors.OutOfRangeError:
pass
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
index b4bc4a6..8449c06 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
@@ -95,7 +95,7 @@
self.assertAllEqual(component[(i * 14 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Batch of a finite input, where the batch_size does not
# divide the total number of elements.
@@ -115,12 +115,12 @@
self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
result_component[j])
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Batch of an empty input should fail straight away.
sess.run(init_op, feed_dict={count: 0, batch_size: 8})
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Empty batch should be an initialization time error.
with self.assertRaises(errors.InvalidArgumentError):
@@ -157,7 +157,7 @@
if not drop_remainder:
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
@parameterized.named_parameters(
("Normal", False),
@@ -181,7 +181,7 @@
self.assertAllEqual([[16], [25], [36], [49]], self.evaluate(next_element))
self.assertAllEqual([[64], [81]], self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
@parameterized.named_parameters(
("Normal", False),
@@ -208,7 +208,7 @@
expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
self.assertAllEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(elements)
+ self.evaluate(elements)
@parameterized.named_parameters(
("Normal", False),
@@ -237,7 +237,7 @@
expected.append(range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100))
self.assertAllEqual(got, expected)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(elements)
+ self.evaluate(elements)
@parameterized.named_parameters(
("Normal", False),
@@ -271,7 +271,7 @@
self.assertTrue(sparse_tensor.is_sparse(actual))
self.assertSparseValuesEqual(actual, expected)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
@parameterized.named_parameters(
("Normal", False),
@@ -324,7 +324,7 @@
self.evaluate(init_op)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"number of elements does not match"):
- sess.run(get_next)
+ self.evaluate(get_next)
@parameterized.named_parameters(
("Normal", False),
@@ -354,7 +354,7 @@
with self.cached_session() as sess:
for _ in range(3):
- sess.run(get_next)
+ self.evaluate(get_next)
@parameterized.named_parameters(
("1", 0, False),
@@ -398,9 +398,9 @@
if threshold % 10 != 0:
self.assertAllEqual(
[threshold // 10 * 10 + j for j in range(threshold % 10)],
- sess.run(get_next))
+ self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
@parameterized.named_parameters(
("1", False, dtypes.bool, False),
@@ -503,13 +503,13 @@
print("Case %d" % i)
if i < 5:
self.assertAllEqual([i * 10 + j + 1 for j in range(10)],
- sess.run(get_next))
+ self.evaluate(get_next))
else:
self.assertAllEqual(
[((i * 10) + j) * ((i * 10) + j) for j in range(10)],
- sess.run(get_next))
+ self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
index 3cf3b89..6042ca1 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
@@ -260,10 +260,10 @@
with session.Session() as sess:
# Warm up the session
for _ in range(5):
- sess.run(op)
+ self.evaluate(op)
start = time.time()
for _ in range(num_iters):
- sess.run(op)
+ self.evaluate(op)
end = time.time()
mean_us = (end - start) * 1e6 / num_iters
self.report_benchmark(
diff --git a/tensorflow/python/data/kernel_tests/matching_files_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py
similarity index 62%
rename from tensorflow/python/data/kernel_tests/matching_files_dataset_op_test.py
rename to tensorflow/python/data/experimental/kernel_tests/matching_files_test.py
index 4d86ec4..938dd4a 100644
--- a/tensorflow/python/data/kernel_tests/matching_files_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/matching_files_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for the private `MatchingFilesDataset`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -20,20 +20,15 @@
import os
import shutil
import tempfile
-import time
-import numpy as np
-
-from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import matching_files
from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops.dataset_ops import MatchingFilesDataset
from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class MatchingFilesDatasetTest(test_base.DatasetTestBase):
+class MatchingFilesTest(test_base.DatasetTestBase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
@@ -46,31 +41,34 @@
open(os.path.join(self.tmp_dir, filename), 'a').close()
def testNonExistingDirectory(self):
- """Test the MatchingFiles dataset with a non-existing directory"""
+ """Test the MatchingFiles dataset with a non-existing directory."""
self.tmp_dir = os.path.join(self.tmp_dir, 'nonexistingdir')
- dataset = MatchingFilesDataset(os.path.join(self.tmp_dir, '*'))
+ dataset = matching_files.MatchingFilesDataset(
+ os.path.join(self.tmp_dir, '*'))
with self.cached_session() as sess:
next_element = dataset.make_one_shot_iterator().get_next()
with self.assertRaises(errors.NotFoundError):
sess.run(next_element)
def testEmptyDirectory(self):
- """Test the MatchingFiles dataset with an empty directory"""
+ """Test the MatchingFiles dataset with an empty directory."""
- dataset = MatchingFilesDataset(os.path.join(self.tmp_dir, '*'))
+ dataset = matching_files.MatchingFilesDataset(
+ os.path.join(self.tmp_dir, '*'))
with self.cached_session() as sess:
next_element = dataset.make_one_shot_iterator().get_next()
with self.assertRaises(errors.NotFoundError):
sess.run(next_element)
def testSimpleDirectory(self):
- """Test the MatchingFiles dataset with a simple directory"""
+ """Test the MatchingFiles dataset with a simple directory."""
filenames = ['a', 'b', 'c']
self._touchTempFiles(filenames)
- dataset = MatchingFilesDataset(os.path.join(self.tmp_dir, '*'))
+ dataset = matching_files.MatchingFilesDataset(
+ os.path.join(self.tmp_dir, '*'))
with self.cached_session() as sess:
next_element = dataset.make_one_shot_iterator().get_next()
@@ -86,12 +84,13 @@
sess.run(next_element)
def testFileSuffixes(self):
- """Test the MatchingFiles dataset using the suffixes of filename"""
+ """Test the MatchingFiles dataset using the suffixes of filename."""
filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc']
self._touchTempFiles(filenames)
- dataset = MatchingFilesDataset(os.path.join(self.tmp_dir, '*.py'))
+ dataset = matching_files.MatchingFilesDataset(
+ os.path.join(self.tmp_dir, '*.py'))
with self.cached_session() as sess:
next_element = dataset.make_one_shot_iterator().get_next()
expected_filenames = []
@@ -106,12 +105,13 @@
sess.run(next_element)
def testFileMiddles(self):
- """Test the MatchingFiles dataset using the middles of filename"""
+ """Test the MatchingFiles dataset using the middles of filename."""
filenames = ['aa.txt', 'bb.py', 'bbc.pyc', 'cc.pyc']
self._touchTempFiles(filenames)
- dataset = MatchingFilesDataset(os.path.join(self.tmp_dir, 'b*.py*'))
+ dataset = matching_files.MatchingFilesDataset(
+ os.path.join(self.tmp_dir, 'b*.py*'))
with self.cached_session() as sess:
next_element = dataset.make_one_shot_iterator().get_next()
expected_filenames = []
@@ -126,7 +126,7 @@
sess.run(next_element)
def testNestedDirectories(self):
- """Test the MatchingFiles dataset with nested directories"""
+ """Test the MatchingFiles dataset with nested directories."""
filenames = []
width = 8
@@ -147,7 +147,7 @@
suffix) for suffix in ['*.txt', '*.log']
]
- dataset = MatchingFilesDataset(patterns)
+ dataset = matching_files.MatchingFilesDataset(patterns)
with self.cached_session() as sess:
next_element = dataset.make_one_shot_iterator().get_next()
expected_filenames = [
@@ -165,70 +165,5 @@
self.assertItemsEqual(expected_filenames, actual_filenames)
-class MatchingFilesDatasetBenchmark(test.Benchmark):
-
- def benchmarkNestedDirectories(self):
- tmp_dir = tempfile.mkdtemp()
- width = 500
- depth = 10
- for i in range(width):
- for j in range(depth):
- new_base = os.path.join(tmp_dir, str(i),
- *[str(dir_name) for dir_name in range(j)])
- os.makedirs(new_base)
- child_files = ['a.py', 'b.pyc'] if j < depth - 1 else ['c.txt', 'd.log']
- for f in child_files:
- filename = os.path.join(new_base, f)
- open(filename, 'w').close()
-
- patterns = [
- os.path.join(tmp_dir, os.path.join(*['**'
- for _ in range(depth)]), suffix)
- for suffix in ['*.txt', '*.log']
- ]
-
- deltas = []
- iters = 3
- for _ in range(iters):
- with ops.Graph().as_default():
- dataset = MatchingFilesDataset(patterns)
- next_element = dataset.make_one_shot_iterator().get_next()
-
- with session.Session() as sess:
- sub_deltas = []
- while True:
- try:
- start = time.time()
- sess.run(next_element)
- end = time.time()
- sub_deltas.append(end - start)
- except errors.OutOfRangeError:
- break
- deltas.append(sub_deltas)
-
- median_deltas = np.median(deltas, axis=0)
- print('Nested directory size (width*depth): %d*%d Median wall time: '
- '%fs (read first filename), %fs (read second filename), avg %fs'
- ' (read %d more filenames)' %
- (width, depth, median_deltas[0], median_deltas[1],
- np.average(median_deltas[2:]), len(median_deltas) - 2))
- self.report_benchmark(
- iters=iters,
- wall_time=np.sum(median_deltas),
- extras={
- 'read first file:':
- median_deltas[0],
- 'read second file:':
- median_deltas[1],
- 'avg time for reading %d more filenames:' %
- (len(median_deltas) - 2):
- np.average(median_deltas[2:])
- },
- name='benchmark_matching_files_dataset_nesteddirectory(%d*%d)' %
- (width, depth))
-
- shutil.rmtree(tmp_dir, ignore_errors=True)
-
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
index 1d0e6af..121798a 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
@@ -221,15 +221,14 @@
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
- "//tensorflow/python:nn_ops",
+ "//tensorflow/python:nn",
"//tensorflow/python:parsing_ops",
- "//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
@@ -249,12 +248,9 @@
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
index c2665cf..fc65f52 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
@@ -60,7 +60,8 @@
["LatencyStats", "Map", "LatencyStats", "Prefetch",
"LatencyStats"])).map(lambda x: x * x).prefetch(1)
options = dataset_ops.Options()
- options.experimental_stats = stats_options.StatsOptions(aggregator)
+ options.experimental_stats = stats_options.StatsOptions()
+ options.experimental_stats.aggregator = aggregator
dataset = dataset.with_options(options)
self.assertDatasetProduces(
dataset,
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
index 470de58..4f05f02 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
@@ -17,18 +17,14 @@
from __future__ import division
from __future__ import print_function
-import time
-
from absl.testing import parameterized
import numpy as np
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
-from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -439,102 +435,5 @@
("IteratorGetNext", "IteratorGetNext_1", 1)])
-class MapVectorizationBenchmark(test.Benchmark):
- # TODO(rachelim): Add a benchmark for more expensive transformations, such as
- # vgg_preprocessing.
-
- def _run(self, x, num_iters=100, name=None):
- deltas = []
- with session.Session() as sess:
- for _ in range(5):
- # Warm up session...
- sess.run(x)
- for _ in range(num_iters):
- start = time.time()
- sess.run(x)
- end = time.time()
- deltas.append(end - start)
- median_time = np.median(deltas)
- self.report_benchmark(iters=num_iters, wall_time=median_time, name=name)
- return median_time
-
- def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
- num_elems = int(np.sum([np.prod(x) for x in input_size]))
- name_template = "{}__batch_size_{}_input_element_size_{}_{}"
- unoptimized = input_dataset.map(map_fn).batch(batch_size)
- unoptimized_op = unoptimized.make_one_shot_iterator().get_next()
-
- optimized = input_dataset.map(map_fn).batch(batch_size)
- options = dataset_ops.Options()
- options.experimental_map_vectorization = True
- optimized = optimized.with_options(options)
- optimized_op = optimized.make_one_shot_iterator().get_next()
-
- unoptimized_time = self._run(
- unoptimized_op,
- name=name_template.format(str_id, batch_size, num_elems, "unoptimized"))
- optimized_time = self._run(
- optimized_op,
- name=name_template.format(str_id, batch_size, num_elems, "optimized"))
-
- print("Batch size: {}\n"
- "Input element size: {}\n"
- "Transformation: {}\n"
- "Speedup: {}\n".format(batch_size, input_size, str_id,
- (unoptimized_time / optimized_time)))
-
- # Known cheap functions
- def benchmarkIdentity(self):
- self._benchmark_helper(lambda *args: [array_ops.identity(x) for x in args],
- "identity")
-
- def benchmarkAddConst(self):
- self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
-
- def benchmarkReturnConst(self):
- self._benchmark_helper(lambda *args: [constant_op.constant(2)], "ret_const")
-
- def benchmarkSelect(self):
- self._benchmark_helper(lambda *args: args[0], "select")
-
- def benchmarkCast(self):
- self._benchmark_helper(
- lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], "cast")
-
- def benchmarkReshape(self):
- self._benchmark_helper(
- lambda *args: [array_ops.reshape(x, (-1, 30)) for x in args], "reshape")
-
- def benchmarkDecodeCSV(self):
- csv_fn, csv_factory = _generate_csv_test_case()
- self._benchmark_helper(csv_fn, "decode_csv", lambda: [csv_factory()])
-
- def benchmarkParseSingleExample(self):
- # NOTE: Since we haven't implemented a vectorizer for "SerializeSparse",
- # this function is only naively vectorized.
- parse_fn, parse_factory = _generate_parse_single_example_test_case()
-
- self._benchmark_helper(parse_fn, "parse_single_example",
- lambda: [parse_factory()])
-
- def _default_dataset_factory(self):
- input_sizes = [(10, 10, 3), (10, 100, 300)]
- for sz in input_sizes:
- yield dataset_ops.Dataset.from_tensor_slices(np.random.rand(*sz))
-
- def _benchmark_helper(self, map_fn, str_id, base_dataset_factory=None):
- if base_dataset_factory is None:
- base_dataset_factory = self._default_dataset_factory
-
- batch_size = 1000
- for base_dataset in base_dataset_factory():
- base_dataset = base_dataset.repeat()
- input_size = [
- tuple(shape.as_list())
- for shape in nest.flatten(base_dataset.output_shapes)
- ]
- self._compare(base_dataset, map_fn, batch_size, input_size, str_id)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_test.py
index f5a8399..d3c1214 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_test.py
@@ -17,182 +17,18 @@
from __future__ import division
from __future__ import print_function
-import time
-
from absl.testing import parameterized
-import numpy as np
-from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
-from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
# TODO(b/117581999): Add eager coverage for the following tests.
class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
- def testModelMap(self):
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
- np.random.rand(4 * k,
- 1))).repeat()
- dataset = dataset.map(math_ops.matmul)
- dataset = dataset_ops._ModelDataset(dataset)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next.op)
- for _ in range(100):
- start = time.time()
- sess.run(get_next.op)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
- def testModelParallelMap(self):
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
- np.random.rand(4 * k,
- 1))).repeat()
- dataset = dataset.map(
- math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE)
- dataset = dataset_ops._ModelDataset(dataset)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next.op)
- for _ in range(100):
- start = time.time()
- sess.run(get_next.op)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
- @parameterized.named_parameters(
- ("Default", False),
- ("NUMA", True),
- )
- def testModelMapAndBatch(self, numa_aware):
- batch_size = 16
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
- np.random.rand(4 * k,
- 1))).repeat()
- dataset = dataset.apply(
- batching.map_and_batch(
- math_ops.matmul,
- num_parallel_calls=optimization.AUTOTUNE,
- batch_size=batch_size))
- dataset = dataset_ops._ModelDataset(dataset)
- options = dataset_ops.Options()
- options.experimental_numa_aware = numa_aware
- dataset = dataset.with_options(options)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next.op)
- for _ in range(10):
- start = time.time()
- sess.run(get_next.op)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
- def testModelParallelInterleave(self):
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
- np.random.rand(4 * k,
- 1))).repeat()
- dataset = dataset.map(math_ops.matmul)
- dataset = dataset_ops.Dataset.range(1).repeat().interleave(
- lambda _: dataset,
- cycle_length=10,
- num_parallel_calls=optimization.AUTOTUNE)
- dataset = dataset_ops._ModelDataset(dataset)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next.op)
- for _ in range(100):
- start = time.time()
- sess.run(get_next.op)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
- def testModelNested(self):
- k = 1024 * 1024
- a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1))
- b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))
- c = (np.random.rand(1, 2 * k), np.random.rand(2 * k, 1))
- dataset = dataset_ops.Dataset.from_tensors((a, b, c)).repeat()
-
- def f1(a, b, c):
- x, y = a
- return math_ops.matmul(x, y), b, c
-
- def f2(a, b, c):
- x, y = b
- return a, math_ops.matmul(x, y), c
-
- def f3(a, b, c):
- x, y = c
- return a, b, math_ops.matmul(x, y)
-
- dataset = dataset.map(f1, num_parallel_calls=optimization.AUTOTUNE)
- dataset = dataset_ops.Dataset.range(1).repeat().interleave(
- lambda _: dataset, cycle_length=2)
-
- dataset = dataset.map(f2, num_parallel_calls=optimization.AUTOTUNE)
- dataset = dataset_ops.Dataset.range(1).repeat().interleave(
- lambda _: dataset, cycle_length=2)
-
- dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE)
- dataset = dataset_ops._ModelDataset(dataset)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next)
- for _ in range(100):
- start = time.time()
- sess.run(get_next)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
def testAutotuneOption(self):
dataset = dataset_ops.Dataset.from_tensors(0)
dataset = dataset.map(lambda x: x).apply(
@@ -205,9 +41,9 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.assertEqual(0, sess.run(get_next))
+ self.assertEqual(0, self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py
index 510b197..df26a2c0 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py
@@ -51,7 +51,7 @@
with self.cached_session() as sess:
sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
- sess.run(get_next)
+ self.evaluate(get_next)
# TODO(b/117581999): Add eager coverage for the following tests.
def testSkipEagerOptimizationLargeInputFromTensorSlices(self):
@@ -64,7 +64,7 @@
with self.cached_session() as sess:
sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
- sess.run(get_next)
+ self.evaluate(get_next)
def testOptimizationNestedDataset(self):
diff --git a/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py b/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py
index ca8bc5f..1dfe854 100644
--- a/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py
@@ -22,6 +22,7 @@
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.experimental.ops import threadpool
from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.kernel_tests import test_base
@@ -35,18 +36,7 @@
class OverrideThreadpoolTest(test_base.DatasetTestBase,
parameterized.TestCase):
- @parameterized.named_parameters(
- ("1", 1, None),
- ("2", 2, None),
- ("3", 4, None),
- ("4", 8, None),
- ("5", 16, None),
- ("6", 4, -1),
- ("7", 4, 0),
- ("8", 4, 1),
- ("9", 4, 4),
- )
- def testNumThreads(self, num_threads, max_intra_op_parallelism):
+ def _testNumThreadsHelper(self, num_threads, override_threadpool_fn):
def get_thread_id(_):
# Python creates a dummy thread object to represent the current
@@ -60,14 +50,7 @@
dataset_ops.Dataset.range(1000).map(
lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
num_parallel_calls=32).apply(unique.unique()))
-
- dataset = threadpool.override_threadpool(
- dataset,
- threadpool.PrivateThreadPool(
- num_threads,
- max_intra_op_parallelism=max_intra_op_parallelism,
- display_name="private_thread_pool_%d" % num_threads))
-
+ dataset = override_threadpool_fn(dataset)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
@@ -76,15 +59,67 @@
thread_ids = []
try:
while True:
- thread_ids.append(sess.run(next_element))
+ thread_ids.append(self.evaluate(next_element))
except errors.OutOfRangeError:
pass
- self.assertEqual(len(thread_ids), len(set(thread_ids)))
- self.assertGreater(len(thread_ids), 0)
- # NOTE(mrry): We don't control the thread pool scheduling, and
- # so cannot guarantee that all of the threads in the pool will
- # perform work.
- self.assertLessEqual(len(thread_ids), num_threads)
+ self.assertLen(thread_ids, len(set(thread_ids)))
+ self.assertNotEmpty(thread_ids)
+ if num_threads:
+ # NOTE(mrry): We don't control the thread pool scheduling, and
+ # so cannot guarantee that all of the threads in the pool will
+ # perform work.
+ self.assertLessEqual(len(thread_ids), num_threads)
+
+ @parameterized.named_parameters(
+ ("1", 1, None),
+ ("2", 2, None),
+ ("3", 4, None),
+ ("4", 8, None),
+ ("5", 16, None),
+ ("6", 4, -1),
+ ("7", 4, 0),
+ ("8", 4, 1),
+ ("9", 4, 4),
+ )
+ def testNumThreadsDeprecated(self, num_threads, max_intra_op_parallelism):
+
+ def override_threadpool_fn(dataset):
+ return threadpool.override_threadpool(
+ dataset,
+ threadpool.PrivateThreadPool(
+ num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name="private_thread_pool_%d" % num_threads))
+
+ self._testNumThreadsHelper(num_threads, override_threadpool_fn)
+
+ @parameterized.named_parameters(
+ ("1", 1, None),
+ ("2", 2, None),
+ ("3", 4, None),
+ ("4", 8, None),
+ ("5", 16, None),
+ ("6", None, 0),
+ ("7", None, 1),
+ ("8", None, 4),
+ ("9", 4, 0),
+ ("10", 4, 1),
+ ("11", 4, 4),
+ ("12", None, None),
+ )
+ def testNumThreads(self, num_threads, max_intra_op_parallelism):
+
+ def override_threadpool_fn(dataset):
+ t_options = threading_options.ThreadingOptions()
+ if max_intra_op_parallelism is not None:
+ t_options.max_intra_op_parallelism = max_intra_op_parallelism
+ if num_threads is not None:
+ t_options.private_threadpool_size = num_threads
+ options = dataset_ops.Options()
+ options.experimental_threading = t_options
+ return dataset.with_options(options)
+
+ self._testNumThreadsHelper(num_threads, override_threadpool_fn)
if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
index 91908f5..77f0dc8 100644
--- a/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
@@ -195,9 +195,9 @@
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1):
self.write_coordination_events[expected_element].set()
self.assertEqual(expected_element * expected_element,
- sess.run(self.next_element))
+ self.evaluate(self.next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
def testSingleThreaded(self):
self._testSingleThreaded()
@@ -235,10 +235,10 @@
for expected_element in self._interleave(
[[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1):
self.write_coordination_events[expected_element].set()
- output = sess.run(self.next_element)
+ output = self.evaluate(self.next_element)
self.assertEqual(expected_element * expected_element, output)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
def _testTwoThreadsNoContention(self, sloppy=False):
# num_threads > 1.
@@ -262,7 +262,7 @@
self.write_coordination_events[expected_element].set()
if done_first_event: # First event starts the worker threads.
self.read_coordination_events[expected_element].acquire()
- actual_element = sess.run(self.next_element)
+ actual_element = self.evaluate(self.next_element)
if not done_first_event:
self.read_coordination_events[expected_element].acquire()
done_first_event = True
@@ -270,7 +270,7 @@
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
def testTwoThreadsNoContention(self):
self._testTwoThreadsNoContention()
@@ -309,7 +309,7 @@
else:
self.write_coordination_events[expected_element].set()
time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
- actual_element = sess.run(self.next_element)
+ actual_element = self.evaluate(self.next_element)
if not done_first_event:
done_first_event = True
self.assertTrue(
@@ -318,7 +318,7 @@
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
def testTwoThreadsNoContentionWithRaces(self):
self._testTwoThreadsNoContentionWithRaces()
@@ -348,7 +348,7 @@
self.write_coordination_events[expected_element].set()
if done_first_event: # First event starts the worker threads.
self.read_coordination_events[expected_element].acquire()
- actual_element = sess.run(self.next_element)
+ actual_element = self.evaluate(self.next_element)
if not done_first_event:
done_first_event = True
self.read_coordination_events[expected_element].acquire()
@@ -356,7 +356,7 @@
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
def testTwoThreadsNoContentionBlockLength(self):
self._testTwoThreadsNoContentionBlockLength()
@@ -396,7 +396,7 @@
else:
self.write_coordination_events[expected_element].set()
time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
- actual_element = sess.run(self.next_element)
+ actual_element = self.evaluate(self.next_element)
if not done_first_event:
done_first_event = True
self.assertTrue(
@@ -405,7 +405,7 @@
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
def testTwoThreadsNoContentionWithRacesAndBlocking(self):
self._testTwoThreadsNoContentionWithRacesAndBlocking()
@@ -428,7 +428,7 @@
self.prefetch_input_elements: 0,
})
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
def testEmptyInput(self):
self._testEmptyInput()
@@ -451,7 +451,7 @@
self.prefetch_input_elements: 0,
})
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
def testNonEmptyInputIntoEmptyOutputs(self):
self._testNonEmptyInputIntoEmptyOutputs()
@@ -484,7 +484,7 @@
# presence of finishing iterators.
if done_first_event and not (sloppy and (i in race_indices)):
self.read_coordination_events[expected_element].acquire()
- actual_element = sess.run(self.next_element)
+ actual_element = self.evaluate(self.next_element)
if not done_first_event or (sloppy and (i in race_indices)):
done_first_event = True
self.read_coordination_events[expected_element].acquire()
@@ -520,10 +520,10 @@
]
for element in mis_ordering:
self.write_coordination_events[element].set()
- self.assertEqual(element * element, sess.run(self.next_element))
+ self.assertEqual(element * element, self.evaluate(self.next_element))
self.assertTrue(self.read_coordination_events[element].acquire(False))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
def testBlockLengthWithContentionSloppy(self):
with self.cached_session() as sess:
@@ -549,7 +549,7 @@
self.write_coordination_events[expected_element].set()
if done_first_event: # First event starts the worker threads.
self.read_coordination_events[expected_element].acquire()
- actual_element = sess.run(self.next_element)
+ actual_element = self.evaluate(self.next_element)
if not done_first_event:
self.read_coordination_events[expected_element].acquire()
done_first_event = True
@@ -557,7 +557,7 @@
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
def _testEarlyExit(self, sloppy=False):
# Exiting without consuming all input should not block
@@ -575,7 +575,7 @@
})
for i in range(4, 7):
self.write_coordination_events[i].set()
- elem = sess.run(self.next_element) # Start all workers
+ elem = self.evaluate(self.next_element) # Start all workers
# Allow the one successful worker to progress beyond the py_func again.
elem = int(math.sqrt(elem))
self.write_coordination_events[elem].set()
@@ -608,7 +608,7 @@
with self.cached_session() as sess:
output_values = []
for _ in range(30):
- output_values.append(sess.run(iterator.get_next()))
+ output_values.append(self.evaluate(iterator.get_next()))
expected_values = self._interleave(
[[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
@@ -643,7 +643,7 @@
expected = [i, 0] if j % 2 == 0 else [0, -i]
self.assertAllEqual(expected, self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
def testErrorsInOutputFn(self):
with self.cached_session() as sess:
@@ -668,15 +668,15 @@
self.error = ValueError()
self.write_coordination_events[expected_element].set()
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
else:
self.write_coordination_events[expected_element].set()
- actual_element = sess.run(self.next_element)
+ actual_element = self.evaluate(self.next_element)
self.assertEqual(expected_element * expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
def testErrorsInInputFn(self):
@@ -720,14 +720,14 @@
self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
if expected_element == 5:
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
else:
- actual_element = sess.run(self.next_element)
+ actual_element = self.evaluate(self.next_element)
self.assertEqual(expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
def testErrorsInInterleaveFn(self):
@@ -769,14 +769,14 @@
self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
if expected_element == 5:
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
else:
- actual_element = sess.run(self.next_element)
+ actual_element = self.evaluate(self.next_element)
self.assertEqual(expected_element, actual_element,
"At index %s: %s expected, got: %s" %
(i, expected_element, actual_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
+ self.evaluate(self.next_element)
def testShutdownRace(self):
dataset = dataset_ops.Dataset.range(20)
@@ -799,7 +799,7 @@
self.evaluate(iterator.initializer)
try:
while True:
- elements.extend(sess.run(next_element))
+ elements.extend(self.evaluate(next_element))
except errors.OutOfRangeError:
pass
results.append(elements)
diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
index 60c3741..8fc18e1 100644
--- a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
@@ -59,7 +59,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testPrefetchToSameDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
@@ -89,7 +89,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testPrefetchDictToDevice(self):
host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
@@ -119,7 +119,7 @@
for i in range(10):
self.assertEqual({"a": i}, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testPrefetchSparseTensorsToDevice(self):
def make_tensor(i):
@@ -155,7 +155,7 @@
self.assertAllEqual([[0, 0]], actual.indices)
self.assertAllEqual([2, 2], actual.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testPrefetchToDeviceGpu(self):
if not test_util.is_gpu_available():
@@ -172,7 +172,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testPrefetchToDeviceWithReInit(self):
host_dataset = dataset_ops.Dataset.range(10)
@@ -206,7 +206,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testPrefetchToDeviceGpuWithReInit(self):
if not test_util.is_gpu_available():
@@ -227,7 +227,7 @@
for i in range(10):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/scan_test.py b/tensorflow/python/data/experimental/kernel_tests/scan_test.py
index 0e9bb46..dc8a7bc 100644
--- a/tensorflow/python/data/experimental/kernel_tests/scan_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/scan_test.py
@@ -62,7 +62,7 @@
itertools.count(start_val, step_val), range(take_val)):
self.assertEqual(expected, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
@test_util.run_in_graph_and_eager_modes
def testFibonacci(self):
@@ -112,7 +112,7 @@
itertools.count(start_val, step_val), range(take_val)):
self.assertEqual(expected, self.evaluate(next_element).values[0])
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testChangingStateShape(self):
# Test the fixed-point shape invariant calculations: start with
@@ -140,7 +140,7 @@
self.assertAllEqual([0] * (2**i), longer_vector_val)
self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testIncorrectStateType(self):
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
index 2cfb575..c724987 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
@@ -74,7 +74,10 @@
size = "small",
srcs = ["checkpoint_input_pipeline_hook_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_pip",
+ "notsan",
+ ],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -358,6 +361,7 @@
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:matching_files",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
index 225f6cb..e3ba8ad 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
@@ -17,8 +17,6 @@
from __future__ import division
from __future__ import print_function
-import numpy as np
-
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import sparse_tensor
@@ -35,7 +33,7 @@
def testFilterCore(self):
div = 3
- num_outputs = np.sum([x % 3 != 2 for x in range(100)])
+ num_outputs = sum(x % 3 != 2 for x in range(100))
self.run_core_tests(lambda: self._build_filter_range_graph(div),
lambda: self._build_filter_range_graph(div * 2),
num_outputs)
@@ -47,7 +45,7 @@
lambda d: d["foo"] + d["bar"])
def testFilterDictCore(self):
- num_outputs = np.sum([(x**2) % 2 == 0 for x in range(10)])
+ num_outputs = sum((x**2) % 2 == 0 for x in range(10))
self.run_core_tests(self._build_filter_dict_graph, None, num_outputs)
def _build_sparse_filter(self):
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/matching_files_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/matching_files_dataset_serialization_test.py
index 7edb200..c026e97 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/matching_files_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/matching_files_dataset_serialization_test.py
@@ -22,7 +22,7 @@
import tempfile
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops.dataset_ops import MatchingFilesDataset
+from tensorflow.python.data.experimental.ops import matching_files
from tensorflow.python.platform import test
@@ -30,7 +30,7 @@
dataset_serialization_test_base.DatasetSerializationTestBase):
def _build_iterator_graph(self, test_patterns):
- return MatchingFilesDataset(test_patterns)
+ return matching_files.MatchingFilesDataset(test_patterns)
def testMatchingFilesCore(self):
tmp_dir = tempfile.mkdtemp()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
index 704a407..aeb338d 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
@@ -85,7 +85,7 @@
for i in range(break_point, stop):
self.assertEqual(i, self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Saving and restoring in same session.
with ops.Graph().as_default() as g:
@@ -100,7 +100,7 @@
for i in range(break_point, stop):
self.assertEqual(i, self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
def _build_range_dataset(self, start, stop):
return dataset_ops.Dataset.range(start, stop)
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
index 496fd45..12fa098 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
@@ -60,7 +60,7 @@
init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
num_outputs)
with self.session(graph=g) as sess:
- sess.run(init_ops)
+ self.evaluate(init_ops)
for _ in range(break_point):
output = self.evaluate(get_next_ops)
for i in range(num_pipelines):
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
index a04f1dd..e753a7a 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
@@ -138,9 +138,9 @@
saver = saver_lib.Saver(allow_empty=True)
with self.session(graph=g) as sess:
self._save(sess, saver)
- expected = [sess.run(get_next_ops) for _ in range(num_outputs)]
+ expected = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
self._restore(saver, sess)
- actual = [sess.run(get_next_ops) for _ in range(num_outputs)]
+ actual = [self.evaluate(get_next_ops) for _ in range(num_outputs)]
self.match(expected, actual)
diff --git a/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py b/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
index 5f7d905..2e8b93f 100644
--- a/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
@@ -38,10 +38,10 @@
outputs = []
with self.cached_session() as sess:
for _ in range(num_outputs):
- outputs.append(sess.run(get_next))
+ outputs.append(self.evaluate(get_next))
if verify_exhausted:
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
return outputs
def testCorrectOutput(self):
diff --git a/tensorflow/python/data/experimental/kernel_tests/sleep_test.py b/tensorflow/python/data/experimental/kernel_tests/sleep_test.py
index f7d42bc..1a6d552 100644
--- a/tensorflow/python/data/experimental/kernel_tests/sleep_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/sleep_test.py
@@ -45,7 +45,7 @@
end_time = time.time()
self.assertGreater(end_time - start_time, (10 * sleep_microseconds) / 1e6)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py
index e11bad7..eb66927 100644
--- a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py
@@ -43,7 +43,7 @@
self.assertEqual((b"Jane", b"Moe", b"Hi again!"),
self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that SqlDataset works on a join query.
def testReadResultSetJoinQuery(self):
@@ -62,7 +62,7 @@
self.assertEqual((b"John", b"California", b"Hi!"),
self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that SqlDataset can read a database entry with a null-terminator
# in the middle of the text and place the entry in a `string` tensor.
@@ -81,7 +81,7 @@
self.assertEqual((b"Jane", b"Moe", b"nonsense\0"),
self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that SqlDataset works when used on two different queries.
# Because the output types of the dataset must be determined at graph-creation
@@ -99,7 +99,7 @@
self.assertEqual((b"John", b"Doe", b"Hi!"), self.evaluate(get_next))
self.assertEqual((b"Jane", b"Moe", b"Hi again!"), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
sess.run(
init_op,
feed_dict={
@@ -109,9 +109,9 @@
self.assertEqual((b"John", b"Doe", b"California"),
self.evaluate(get_next))
self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
- sess.run(get_next))
+ self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that an `OutOfRangeError` is raised on the first call to
# `get_next_str_only` if result set is empty.
@@ -126,7 +126,7 @@
"WHERE first_name = 'Nonexistent'"
})
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that an error is raised when `driver_name` is invalid.
def testReadResultSetWithInvalidDriverName(self):
@@ -155,7 +155,7 @@
"ORDER BY first_name DESC"
})
with self.assertRaises(errors.UnknownError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that an error is raised when there is a syntax error in `query`.
def testReadResultSetOfQueryWithSyntaxError(self):
@@ -170,7 +170,7 @@
"ORDER BY first_name DESC"
})
with self.assertRaises(errors.UnknownError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that an error is raised when the number of columns in `query`
# does not match the length of `output_types`.
@@ -185,7 +185,7 @@
"ORDER BY first_name DESC"
})
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that no results are returned when `query` is an insert query rather
# than a select query. In particular, the error refers to the number of
@@ -203,7 +203,7 @@
"VALUES ('Foo', 'Bar', 'Baz'), ('Fizz', 'Buzz', 'Fizzbuzz')"
})
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int8` tensor.
@@ -219,7 +219,7 @@
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int8` tensor.
@@ -236,7 +236,7 @@
})
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int8` tensor.
@@ -254,7 +254,7 @@
# Max and min values of int8
self.assertEqual((127, -128), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int16` tensor.
@@ -270,7 +270,7 @@
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int16` tensor.
@@ -287,7 +287,7 @@
})
self.assertEqual((b"John", 0, -2), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int16` tensor.
@@ -305,7 +305,7 @@
# Min value of int16
self.assertEqual((b"Jane", -32768), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int32` tensor.
@@ -335,7 +335,7 @@
self.assertEqual((b"John", 0), self.evaluate(get_next))
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int32` tensor.
@@ -353,7 +353,7 @@
# Min value of int32
self.assertEqual((b"Jane", -2147483648), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read a numeric `varchar` from a SQLite database
# table and place it in an `int32` tensor.
@@ -369,7 +369,7 @@
self.assertEqual((b"John", 123), self.evaluate(get_next))
self.assertEqual((b"Jane", 1000), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table
# and place it in an `int64` tensor.
@@ -385,7 +385,7 @@
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int64` tensor.
@@ -401,7 +401,7 @@
self.assertEqual((b"John", 0), self.evaluate(get_next))
self.assertEqual((b"Jane", -20000), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int64` tensor.
@@ -420,7 +420,7 @@
# Min value of int64
self.assertEqual((b"Jane", -9223372036854775808), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in a `uint8` tensor.
@@ -436,7 +436,7 @@
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read the minimum and maximum uint8 values from a
# SQLite database table and place them in `uint8` tensors.
@@ -454,7 +454,7 @@
# Max value of uint8
self.assertEqual((b"Jane", 255), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table
# and place it in a `uint16` tensor.
@@ -470,7 +470,7 @@
self.assertEqual((b"John", 9), self.evaluate(get_next))
self.assertEqual((b"Jane", 127), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read the minimum and maximum uint16 values from a
# SQLite database table and place them in `uint16` tensors.
@@ -488,7 +488,7 @@
# Max value of uint16
self.assertEqual((b"Jane", 65535), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
# SQLite database table and place them as `True` and `False` respectively
@@ -506,7 +506,7 @@
self.assertEqual((b"John", True), self.evaluate(get_next))
self.assertEqual((b"Jane", False), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
# from a SQLite database table and place it as `True` in a `bool` tensor.
@@ -522,7 +522,7 @@
self.assertEqual((b"John", True), self.evaluate(get_next))
self.assertEqual((b"Jane", True), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read a float from a SQLite database table
# and place it in a `float64` tensor.
@@ -541,7 +541,7 @@
self.evaluate(get_next))
self.assertEqual((b"John", b"Adams", -19.95), self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read a float from a SQLite database table beyond
# the precision of 64-bit IEEE, without throwing an error. Test that
@@ -560,13 +560,13 @@
self.assertEqual(
(b"George", b"Washington",
1331241.321342132321324589798264627463827647382647382643874),
- sess.run(get_next))
+ self.evaluate(get_next))
self.assertEqual(
(b"John", b"Adams",
1331241321342132321324589798264627463827647382647382643874.0),
- sess.run(get_next))
+ self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
# Test that `SqlDataset` can read a float from a SQLite database table,
# representing the largest integer representable as a 64-bit IEEE float
@@ -584,11 +584,11 @@
"ORDER BY first_name"
})
self.assertNotEqual((b"George", b"Washington", 9007199254740992.0),
- sess.run(get_next))
+ self.evaluate(get_next))
self.assertNotEqual((b"John", b"Adams", 9007199254740991.0),
- sess.run(get_next))
+ self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ self.evaluate(get_next)
if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
index 958c3f0..e816006 100644
--- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
@@ -45,22 +45,18 @@
def function_apply_options(dataset, aggregator, prefix="", counter_prefix=""):
options = dataset_ops.Options()
- options.experimental_stats = stats_options.StatsOptions(aggregator)
+ options.experimental_stats = stats_options.StatsOptions()
+ options.experimental_stats.aggregator = aggregator
+ options.experimental_stats.prefix = prefix
+ options.experimental_stats.counter_prefix = counter_prefix
options.experimental_stats.latency_all_edges = False
- if prefix:
- options.experimental_stats.prefix = prefix
- if counter_prefix:
- options.experimental_stats.counter_prefix = counter_prefix
return dataset.with_options(options)
@parameterized.named_parameters(
- dict(
- testcase_name="SetStatsAggregator",
- dataset_transformation=function_set_stats_aggregator),
- dict(
- testcase_name="StatsOptions",
- dataset_transformation=function_apply_options))
+ ("SetStatsAggregator", function_set_stats_aggregator),
+ ("StatsOptions", function_apply_options),
+)
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
def testBytesProduced(self, dataset_transformation):
@@ -78,13 +74,13 @@
expected_sum = 0.0
for i in range(100):
self.assertAllEqual(
- np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ np.array([i] * i, dtype=np.int64), self.evaluate(next_element))
summary_str = self.evaluate(summary_t)
self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
expected_sum += i * 8.0
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
summary_str = self.evaluate(summary_t)
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
@@ -103,9 +99,9 @@
for i in range(100):
self.assertEqual(i, self.evaluate(next_element))
self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float(i + 1))
+ self.evaluate(summary_t), "record_latency", float(i + 1))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
self._assertSummaryHasCount(
self.evaluate(summary_t), "record_latency", 100.0)
@@ -122,7 +118,7 @@
self.evaluate(iterator.initializer)
for i in range(100):
self.assertAllEqual(
- np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ np.array([i] * i, dtype=np.int64), self.evaluate(next_element))
summary_str = self.evaluate(summary_t)
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
float(i + 1))
@@ -131,7 +127,7 @@
self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
0, 1)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
summary_str = self.evaluate(summary_t)
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
100)
@@ -149,14 +145,14 @@
self.evaluate(iterator.initializer)
for i in range(10):
self.assertAllEqual(
- np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ np.array([i] * i, dtype=np.int64), self.evaluate(next_element))
summary_str = self.evaluate(summary_t)
self._assertSummaryHasScalarValue(summary_str,
"Prefetch::buffer_capacity", 0)
self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
0)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testFilteredElementsStats(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
@@ -173,15 +169,16 @@
self.assertEqual(i * 3, self.evaluate(next_element))
if i is not 0:
self._assertSummaryHasScalarValue(
- sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
+ self.evaluate(summary_t), "Filter::dropped_elements",
+ float(i * 2))
self._assertSummaryHasScalarValue(
- sess.run(summary_t), "Filter::filtered_elements", float(i + 1))
+ self.evaluate(summary_t), "Filter::filtered_elements", float(i + 1))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
self._assertSummaryHasScalarValue(
- sess.run(summary_t), "Filter::dropped_elements", 67.0)
+ self.evaluate(summary_t), "Filter::dropped_elements", 67.0)
self._assertSummaryHasScalarValue(
- sess.run(summary_t), "Filter::filtered_elements", 34.0)
+ self.evaluate(summary_t), "Filter::filtered_elements", 34.0)
def testMapBufferUtilization(self, dataset_transformation):
@@ -266,11 +263,12 @@
for i in range(100):
self.assertEqual(i, self.evaluate(next_element))
self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
+ self.evaluate(summary_t), "record_latency",
+ float((j * 100) + i + 1))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", (j + 1) * 100.0)
+ self.evaluate(summary_t), "record_latency", (j + 1) * 100.0)
def testNoAggregatorRegistered(self, dataset_transformation):
dataset = dataset_ops.Dataset.range(100).apply(
@@ -283,7 +281,7 @@
for i in range(100):
self.assertEqual(i, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testMultipleTags(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
@@ -300,15 +298,15 @@
for i in range(100):
self.assertEqual(i, self.evaluate(next_element))
self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float(i + 1))
+ self.evaluate(summary_t), "record_latency", float(i + 1))
self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency_2", float(i + 1))
+ self.evaluate(summary_t), "record_latency_2", float(i + 1))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
self._assertSummaryHasCount(
self.evaluate(summary_t), "record_latency", 100.0)
self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency_2", 100.0)
+ self.evaluate(summary_t), "record_latency_2", 100.0)
def testRepeatedTags(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
@@ -325,9 +323,9 @@
for i in range(100):
self.assertEqual(i, self.evaluate(next_element))
self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float(2 * (i + 1)))
+ self.evaluate(summary_t), "record_latency", float(2 * (i + 1)))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
self._assertSummaryHasCount(
self.evaluate(summary_t), "record_latency", 200.0)
@@ -342,13 +340,13 @@
summary_t = aggregator.get_summary()
with self.cached_session() as sess:
- sess.run([iterator_0.initializer, iterator_1.initializer])
+ self.evaluate([iterator_0.initializer, iterator_1.initializer])
for i in range(100):
self.assertEqual(i * 2, self.evaluate(next_element))
self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float(2 * (i + 1)))
+ self.evaluate(summary_t), "record_latency", float(2 * (i + 1)))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
self._assertSummaryHasCount(
self.evaluate(summary_t), "record_latency", 200.0)
@@ -366,19 +364,19 @@
summary_t = aggregator.get_summary()
with self.test_session() as sess:
- sess.run([iterator_0.initializer, iterator_1.initializer])
+ self.evaluate([iterator_0.initializer, iterator_1.initializer])
for i in range(100):
self.assertEqual(i * 2, self.evaluate(next_element))
self._assertSummaryHasCount(
- sess.run(summary_t), "dataset1_record_latency", float(i + 1))
+ self.evaluate(summary_t), "dataset1_record_latency", float(i + 1))
self._assertSummaryHasCount(
- sess.run(summary_t), "dataset2_record_latency", float(i + 1))
+ self.evaluate(summary_t), "dataset2_record_latency", float(i + 1))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
self._assertSummaryHasCount(
- sess.run(summary_t), "dataset1_record_latency", 100.0)
+ self.evaluate(summary_t), "dataset1_record_latency", 100.0)
self._assertSummaryHasCount(
- sess.run(summary_t), "dataset2_record_latency", 100.0)
+ self.evaluate(summary_t), "dataset2_record_latency", 100.0)
@parameterized.named_parameters(
@@ -427,18 +425,19 @@
with self.test_session() as sess:
self.evaluate(iterator.initializer)
for _ in range(num_output):
- sess.run(next_element)
+ self.evaluate(next_element)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
self._assertSummaryHasCount(
- sess.run(summary_t), "record_stats_features", total_records)
+ self.evaluate(summary_t), "record_stats_features", total_records)
self._assertSummaryHasCount(
- sess.run(summary_t), "record_stats_feature-values", total_records)
+ self.evaluate(summary_t), "record_stats_feature-values",
+ total_records)
self._assertSummaryHasSum(
- sess.run(summary_t), "record_stats_features", total_records * 4)
+ self.evaluate(summary_t), "record_stats_features", total_records * 4)
self._assertSummaryHasSum(
- sess.run(summary_t), "record_stats_feature-values",
+ self.evaluate(summary_t), "record_stats_feature-values",
self._sum_keywords(1) * num_epochs + 3 * total_records)
diff --git a/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py b/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py
index 755294a..cb94bb4 100644
--- a/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py
@@ -17,19 +17,16 @@
from __future__ import division
from __future__ import print_function
-import time
from absl.testing import parameterized
import numpy as np
-from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -52,7 +49,7 @@
for i in range(4):
self.assertEqual(i, self.evaluate(next_elem))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_elem)
+ self.evaluate(next_elem)
def testUnbatchScalarDataset(self):
data = tuple([math_ops.range(10) for _ in range(3)])
@@ -71,7 +68,7 @@
self.assertEqual((i,) * 3, self.evaluate(op))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
+ self.evaluate(op)
def testUnbatchDatasetWithStrings(self):
data = tuple([math_ops.range(10) for _ in range(3)])
@@ -91,7 +88,7 @@
self.assertEqual((i, compat.as_bytes(str(i)), i), self.evaluate(op))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
+ self.evaluate(op)
def testUnbatchDatasetWithSparseTensor(self):
st = sparse_tensor.SparseTensorValue(
@@ -112,7 +109,7 @@
self.assertEqual([i], st_row.values)
self.assertEqual([10], st_row.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testUnbatchDatasetWithDenseAndSparseTensor(self):
st = sparse_tensor.SparseTensorValue(
@@ -134,7 +131,7 @@
self.assertEqual([i], st_row.values)
self.assertEqual([10], st_row.dense_shape)
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testUnbatchSingleElementTupleDataset(self):
data = tuple([(math_ops.range(10),) for _ in range(3)])
@@ -153,7 +150,7 @@
self.assertEqual(((i,),) * 3, self.evaluate(op))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
+ self.evaluate(op)
def testUnbatchMultiElementTupleDataset(self):
data = tuple([(math_ops.range(10 * i, 10 * i + 10),
@@ -171,10 +168,10 @@
with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
- sess.run(op))
+ self.evaluate(op))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
+ self.evaluate(op)
def testUnbatchEmpty(self):
data = dataset_ops.Dataset.from_tensors(
@@ -186,7 +183,7 @@
with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testUnbatchStaticShapeMismatch(self):
data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
@@ -211,7 +208,7 @@
ph2: np.arange(8).astype(np.int32)
})
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(next_element)
+ self.evaluate(next_element)
# No 0th dimension (i.e. scalar value) for one component.
sess.run(
@@ -221,79 +218,7 @@
ph2: 7
})
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(next_element)
-
-
-class UnbatchBenchmark(test.Benchmark):
-
- def benchmarkNativeUnbatch(self):
- batch_sizes = [1, 2, 5, 10, 20, 50]
- elems_per_trial = 10000
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
- batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset.batch(batch_size_placeholder)
- dataset = dataset.apply(batching.unbatch())
- dataset = dataset.skip(elems_per_trial)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for batch_size in batch_sizes:
- deltas = []
- for _ in range(5):
- sess.run(
- iterator.initializer,
- feed_dict={batch_size_placeholder: batch_size})
- start = time.time()
- sess.run(next_element.op)
- end = time.time()
- deltas.append((end - start) / elems_per_trial)
-
- median_wall_time = np.median(deltas)
- print("Unbatch (native) batch size: %d Median wall time per element:"
- " %f microseconds" % (batch_size, median_wall_time * 1e6))
- self.report_benchmark(
- iters=10000,
- wall_time=median_wall_time,
- name="benchmark_unbatch_dataset_native_batch_size_%d" %
- batch_size)
-
- # Include a benchmark of the previous `unbatch()` implementation that uses
- # a composition of more primitive ops. Eventually we'd hope to generate code
- # that is as good in both cases.
- def benchmarkOldUnbatchImplementation(self):
- batch_sizes = [1, 2, 5, 10, 20, 50]
- elems_per_trial = 10000
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
- batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset.batch(batch_size_placeholder)
- dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
- dataset = dataset.skip(elems_per_trial)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for batch_size in batch_sizes:
- deltas = []
- for _ in range(5):
- sess.run(
- iterator.initializer,
- feed_dict={batch_size_placeholder: batch_size})
- start = time.time()
- sess.run(next_element.op)
- end = time.time()
- deltas.append((end - start) / elems_per_trial)
-
- median_wall_time = np.median(deltas)
- print("Unbatch (unfused) batch size: %d Median wall time per element:"
- " %f microseconds" % (batch_size, median_wall_time * 1e6))
- self.report_benchmark(
- iters=10000,
- wall_time=median_wall_time,
- name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
- batch_size)
+ self.evaluate(next_element)
if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/unique_test.py b/tensorflow/python/data/experimental/kernel_tests/unique_test.py
index 4b14a7e..91f4bc8 100644
--- a/tensorflow/python/data/experimental/kernel_tests/unique_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/unique_test.py
@@ -55,7 +55,7 @@
element = compat.as_bytes(element)
self.assertAllEqual(element, self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testSimpleInt(self):
for dtype in [dtypes.int32, dtypes.int64]:
diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD
index 170fda9..f954485 100644
--- a/tensorflow/python/data/experimental/ops/BUILD
+++ b/tensorflow/python/data/experimental/ops/BUILD
@@ -165,7 +165,7 @@
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/util:structure",
],
)
@@ -189,6 +189,28 @@
)
py_library(
+ name = "map_defun",
+ srcs = ["map_defun.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:tensor_shape",
+ ],
+)
+
+py_library(
+ name = "matching_files",
+ srcs = ["matching_files.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:tensor_shape",
+ ],
+)
+
+py_library(
name = "optimization",
srcs = ["optimization.py"],
srcs_version = "PY2AND3",
@@ -218,17 +240,6 @@
)
py_library(
- name = "map_defun",
- srcs = ["map_defun.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:tensor_shape",
- ],
-)
-
-py_library(
name = "resampling",
srcs = ["resampling.py"],
srcs_version = "PY2AND3",
@@ -303,6 +314,18 @@
srcs_version = "PY2AND3",
deps = [
":stats_aggregator",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/util:options",
+ ],
+)
+
+py_library(
+ name = "threading_options",
+ srcs = ["threading_options.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/util:options",
],
)
@@ -313,9 +336,8 @@
deps = [
"//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
"//tensorflow/python/eager:context",
],
)
@@ -386,6 +408,7 @@
":indexed_dataset_ops",
":interleave_ops",
":map_defun",
+ ":matching_files",
":optimization",
":prefetching_ops",
":readers",
diff --git a/tensorflow/python/data/experimental/ops/enumerate_ops.py b/tensorflow/python/data/experimental/ops/enumerate_ops.py
index a1af98f..04d875c 100644
--- a/tensorflow/python/data/experimental/ops/enumerate_ops.py
+++ b/tensorflow/python/data/experimental/ops/enumerate_ops.py
@@ -26,9 +26,9 @@
@tf_export("data.experimental.enumerate_dataset")
def enumerate_dataset(start=0):
- """A transformation that enumerate the elements of a dataset.
+ """A transformation that enumerates the elements of a dataset.
- It is Similar to python's `enumerate`.
+ It is similar to python's `enumerate`.
For example:
```python
@@ -44,8 +44,8 @@
```
Args:
- start: A `tf.int64` scalar `tf.Tensor`, representing the start
- value for enumeration.
+ start: A `tf.int64` scalar `tf.Tensor`, representing the start value for
+ enumeration.
Returns:
A `Dataset` transformation function, which can be passed to
diff --git a/tensorflow/python/data/experimental/ops/grouping.py b/tensorflow/python/data/experimental/ops/grouping.py
index 80ca710..db10ea3 100644
--- a/tensorflow/python/data/experimental/ops/grouping.py
+++ b/tensorflow/python/data/experimental/ops/grouping.py
@@ -21,6 +21,7 @@
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -448,7 +449,10 @@
def _make_reduce_func(self, reduce_func, input_dataset):
"""Make wrapping defun for reduce_func."""
- nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access
+ nested_dataset = dataset_ops.DatasetStructure(
+ structure.Structure._from_legacy_structure( # pylint: disable=protected-access
+ input_dataset.output_types, input_dataset.output_shapes,
+ input_dataset.output_classes))
wrapped_func = dataset_ops.StructuredFunctionWrapper(
reduce_func,
self._transformation_name(),
@@ -456,11 +460,13 @@
input_shapes=(tensor_shape.scalar(), nested_dataset),
input_types=(dtypes.int64, nested_dataset))
if not isinstance(
- wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access
+ wrapped_func.output_structure, dataset_ops.DatasetStructure):
raise TypeError("`reduce_func` must return a `Dataset` object.")
- self._output_classes = wrapped_func.output_classes.output_classes
- self._output_types = wrapped_func.output_types.output_types
- self._output_shapes = wrapped_func.output_shapes.output_shapes
+ # pylint: disable=protected-access
+ element_structure = wrapped_func.output_structure._element_structure
+ self._output_classes = element_structure._to_legacy_output_classes()
+ self._output_types = element_structure._to_legacy_output_types()
+ self._output_shapes = element_structure._to_legacy_output_shapes()
self._reduce_func = wrapped_func.function
@property
diff --git a/tensorflow/python/data/experimental/ops/matching_files.py b/tensorflow/python/data/experimental/ops/matching_files.py
new file mode 100644
index 0000000..8398f86
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/matching_files.py
@@ -0,0 +1,51 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for matching input filenames."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+
+
+class MatchingFilesDataset(dataset_ops.DatasetSource):
+ """A `Dataset` that list the files according to the input patterns."""
+
+ def __init__(self, patterns):
+ super(MatchingFilesDataset, self).__init__()
+ self._patterns = ops.convert_to_tensor(
+ patterns, dtype=dtypes.string, name="patterns")
+
+ def _as_variant_tensor(self):
+ return ged_ops.experimental_matching_files_dataset(self._patterns)
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ @property
+ def output_types(self):
+ return dtypes.string
+
+
diff --git a/tensorflow/python/data/experimental/ops/parsing_ops.py b/tensorflow/python/data/experimental/ops/parsing_ops.py
index 6615b90..a63eb8c 100644
--- a/tensorflow/python/data/experimental/ops/parsing_ops.py
+++ b/tensorflow/python/data/experimental/ops/parsing_ops.py
@@ -138,10 +138,10 @@
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
- if any([
+ if any(
isinstance(feature, parsing_ops.SparseFeature)
for _, feature in features.items()
- ]):
+ ):
# pylint: disable=protected-access
# pylint: disable=g-long-lambda
out_dataset = out_dataset.map(
diff --git a/tensorflow/python/data/experimental/ops/stats_options.py b/tensorflow/python/data/experimental/ops/stats_options.py
index c088d3d..cd7fdcb 100644
--- a/tensorflow/python/data/experimental/ops/stats_options.py
+++ b/tensorflow/python/data/experimental/ops/stats_options.py
@@ -20,11 +20,12 @@
from __future__ import print_function
from tensorflow.python.data.experimental.ops import stats_aggregator
+from tensorflow.python.data.util import options
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.StatsOptions")
-class StatsOptions(object):
+class StatsOptions(options.OptionsBase):
"""Represents options for collecting dataset stats using `StatsAggregator`.
To apply `StatsOptions` with a `tf.data.Dataset` object, use the following
@@ -52,52 +53,29 @@
```
"""
- for _name, _ty, _default, _docstring in [
- ("aggregator", stats_aggregator.StatsAggregator, None,
- "Associate the given statistics options with the dataset pipeline."),
- ("prefix", str, "",
- "Prefix to prepend all statistics recorded for the input `dataset` with."
- ),
- ("counter_prefix", str, "",
- "Prefix for the statistics recorded as counter."),
- ("latency_all_edges", bool, True,
- "Whether to add latency measurements on all edges."),
- ]:
+ aggregator = options.create_option(
+ name="aggregator",
+ ty=stats_aggregator.StatsAggregator,
+ docstring=
+ "Associates the given statistics aggregator with the dataset pipeline.")
- def _make_getter(name): # pylint: disable=no-self-argument
+ prefix = options.create_option(
+ name="prefix",
+ ty=str,
+ docstring=
+ "Prefix to prepend all statistics recorded for the input `dataset` with.",
+ default="")
- def getter(self):
- return getattr(self, "_" + name)
+ counter_prefix = options.create_option(
+ name="counter_prefix",
+ ty=str,
+ docstring=
+ "Prefix for the statistics recorded as counter.",
+ default="")
- return getter
-
- def _make_setter(name, ty): # pylint: disable=no-self-argument
-
- def setter(self, value):
- if not isinstance(value, ty):
- raise TypeError(
- "Attempting to set the option %s to incompatible value: %r when "
- "it expects %r" % (name, value, ty))
- setattr(self, "_" + name, value)
-
- return setter
-
- vars()["_" + _name] = _default
- vars()[_name] = property(
- _make_getter(_name), _make_setter(_name, _ty), _default, _docstring)
-
- def __init__(self, aggregator=None):
- if aggregator:
- self.aggregator = aggregator
-
- def __eq__(self, other):
- if isinstance(other, self.__class__):
- return self.__dict__ == other.__dict__
- else:
- return False
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def __str__(self):
- return str(self.__dict__)
+ latency_all_edges = options.create_option(
+ name="latency_all_edges",
+ ty=bool,
+ docstring=
+ "Whether to add latency measurements on all edges.",
+ default=True)
diff --git a/tensorflow/python/data/experimental/ops/threading_options.py b/tensorflow/python/data/experimental/ops/threading_options.py
new file mode 100644
index 0000000..98df371
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/threading_options.py
@@ -0,0 +1,50 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for controlling threading in `tf.data` pipelines."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.data.util import options
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.ThreadingOptions")
+class ThreadingOptions(options.OptionsBase):
+ """Represents options for dataset threading.
+
+ To apply `ThreadingOptions` to a `dataset` object, use the following pattern:
+
+ ```python
+ options = dataset_ops.Options()
+ options.experimental_threading = tf.data.experimental.ThreadingOptions()
+ options.experimental_threading.private_threadpool_size = 10
+ dataset = dataset.with_options(options)
+ ```
+ """
+
+ max_intra_op_parallelism = options.create_option(
+ name="max_intra_op_parallelism",
+ ty=int,
+ docstring=
+ "If set, it overrides the maximum degree of intra-op parallelism.")
+
+ private_threadpool_size = options.create_option(
+ name="private_threadpool_size",
+ ty=int,
+ docstring=
+ "If set, the dataset will use a private threadpool of the given size.",
+ default=None)
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 21eed2b..0867471 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -10,122 +10,101 @@
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
tf_py_test(
- name = "batch_dataset_op_test",
+ name = "batch_test",
size = "small",
- srcs = ["batch_dataset_op_test.py"],
+ srcs = ["batch_test.py"],
additional_deps = [
":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
+ "//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python:sparse_tensor",
],
)
tf_py_test(
- name = "cache_dataset_op_test",
+ name = "cache_test",
size = "small",
- srcs = ["cache_dataset_op_test.py"],
+ srcs = ["cache_test.py"],
additional_deps = [
":test_base",
"//third_party/py/numpy",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- ],
-)
-
-tf_py_test(
- name = "concatenate_dataset_op_test",
- size = "small",
- srcs = ["concatenate_dataset_op_test.py"],
- additional_deps = [
- ":test_base",
- "//third_party/py/numpy",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
-tf_py_test(
- name = "dataset_constructor_op_test",
- size = "small",
- srcs = ["dataset_constructor_op_test.py"],
- additional_deps = [
- ":test_base",
- "//third_party/py/numpy",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:session",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- ],
- tags = [
- "manual",
- "nomac", # b/62040583
+ "//tensorflow/python:variables",
],
)
tf_py_test(
- name = "dataset_from_generator_op_test",
- size = "medium",
- srcs = ["dataset_from_generator_op_test.py"],
+ name = "concatenate_test",
+ size = "small",
+ srcs = ["concatenate_test.py"],
additional_deps = [
":test_base",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/util:nest",
],
)
tf_py_test(
- name = "dataset_ops_test",
+ name = "dataset_checkpoint_test",
size = "small",
- srcs = ["dataset_ops_test.py"],
+ srcs = ["dataset_checkpoint_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:variables",
+ ],
+)
+
+tf_py_test(
+ name = "dataset_test",
+ size = "small",
+ srcs = ["dataset_test.py"],
additional_deps = [
":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_tensor",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:sparse_tensor",
],
)
tf_py_test(
- name = "filter_dataset_op_test",
+ name = "filter_test",
size = "small",
- srcs = ["filter_dataset_op_test.py"],
+ srcs = ["filter_test.py"],
additional_deps = [
":test_base",
"//third_party/py/numpy",
@@ -141,12 +120,36 @@
)
tf_py_test(
- name = "flat_map_dataset_op_test",
+ name = "fixed_length_record_dataset_test",
size = "small",
- srcs = ["flat_map_dataset_op_test.py"],
+ srcs = ["fixed_length_record_dataset_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+tf_py_test(
+ name = "flat_map_test",
+ size = "medium",
+ srcs = ["flat_map_test.py"],
additional_deps = [
":test_base",
"//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:session",
@@ -159,58 +162,157 @@
)
tf_py_test(
- name = "list_files_dataset_op_test",
- size = "small",
- srcs = ["list_files_dataset_op_test.py"],
+ name = "from_generator_test",
+ size = "medium",
+ srcs = ["from_generator_test.py"],
additional_deps = [
":test_base",
- "//tensorflow/python:array_ops",
+ "//third_party/py/numpy",
+ "//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-tf_py_test(
- name = "inputs_test",
- size = "small",
- srcs = ["inputs_test.py"],
- additional_deps = [
- ":test_base",
- "@absl_py//absl/testing:parameterized",
- "//third_party/py/numpy",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-tf_py_test(
- name = "interleave_dataset_op_test",
- size = "small",
- srcs = ["interleave_dataset_op_test.py"],
- additional_deps = [
- ":test_base",
- "@absl_py//absl/testing:parameterized",
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
+ "//tensorflow/python:script_ops",
"//tensorflow/python:session",
+ ],
+)
+
+tf_py_test(
+ name = "from_sparse_tensor_slices_test",
+ size = "small",
+ srcs = ["from_sparse_tensor_slices_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_py_test(
+ name = "from_tensors_test",
+ size = "small",
+ srcs = ["from_tensors_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+ tags = [
+ "nomac", # b/62040583
+ ],
+)
+
+tf_py_test(
+ name = "from_tensor_slices_test",
+ size = "small",
+ srcs = ["from_tensor_slices_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+tf_py_test(
+ name = "interleave_test",
+ size = "medium",
+ srcs = ["interleave_test.py"],
+ additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:script_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
- "//tensorflow/python:training",
+ ],
+)
+
+tf_py_test(
+ name = "iterator_checkpoint_test",
+ size = "medium",
+ srcs = ["iterator_checkpoint_test.py"],
+ additional_deps = [
+ ":test_base",
"//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/training/checkpointable:util",
+ "//tensorflow/python:checkpoint_management",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ ],
+ grpc_enabled = True,
+)
+
+tf_py_test(
+ name = "iterator_cluster_test",
+ size = "small",
+ srcs = ["iterator_cluster_test.py"],
+ additional_deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:lookup_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:string_ops",
+ ],
+ grpc_enabled = True,
+ tags = [
+ "no_oss", # Test flaky due to port collisions.
+ "no_windows",
],
)
cuda_py_test(
- name = "iterator_ops_test",
- size = "small",
- srcs = ["iterator_ops_test.py"],
+ name = "iterator_test",
+ size = "medium",
+ srcs = ["iterator_test.py"],
additional_deps = [
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
@@ -249,41 +351,30 @@
)
tf_py_test(
- name = "iterator_ops_cluster_test",
+ name = "list_files_test",
size = "small",
- srcs = ["iterator_ops_cluster_test.py"],
+ srcs = ["list_files_test.py"],
additional_deps = [
- "//tensorflow/core:protos_all_py",
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:session",
+ "//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:lookup_ops",
- ],
- grpc_enabled = True,
- tags = [
- "no_oss", # Test flaky due to port collisions.
- "no_windows",
],
)
tf_py_test(
- name = "map_dataset_op_test",
- size = "small",
- srcs = ["map_dataset_op_test.py"],
+ name = "map_test",
+ size = "medium",
+ srcs = ["map_test.py"],
additional_deps = [
":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -297,27 +388,12 @@
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:script_ops",
+ "//tensorflow/python:session",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_util",
"//tensorflow/python:variable_scope",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-tf_py_test(
- name = "matching_files_dataset_op_test",
- size = "small",
- srcs = ["matching_files_dataset_op_test.py"],
- additional_deps = [
- ":test_base",
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -345,9 +421,9 @@
)
cuda_py_test(
- name = "optional_ops_test",
+ name = "optional_test",
size = "small",
- srcs = ["optional_ops_test.py"],
+ srcs = ["optional_test.py"],
additional_deps = [
":test_base",
"@absl_py//absl/testing:parameterized",
@@ -366,9 +442,30 @@
)
tf_py_test(
- name = "prefetch_dataset_op_test",
+ name = "padded_batch_test",
size = "small",
- srcs = ["prefetch_dataset_op_test.py"],
+ srcs = ["padded_batch_test.py"],
+ additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:util",
+ ],
+)
+
+tf_py_test(
+ name = "prefetch_test",
+ size = "small",
+ srcs = ["prefetch_test.py"],
additional_deps = [
":test_base",
"@absl_py//absl/testing:parameterized",
@@ -377,59 +474,26 @@
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python/data/ops:dataset_ops",
],
)
tf_py_test(
- name = "range_dataset_op_test",
+ name = "range_test",
size = "small",
- srcs = ["range_dataset_op_test.py"],
+ srcs = ["range_test.py"],
additional_deps = [
":test_base",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- ],
-)
-
-tf_py_test(
- name = "reader_dataset_ops_test",
- size = "small",
- srcs = ["reader_dataset_ops_test.py"],
- additional_deps = [
- ":test_base",
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:lib",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python:framework_test_lib",
],
)
tf_py_test(
- name = "reduce_dataset_op_test",
+ name = "reduce_test",
size = "small",
- srcs = ["reduce_dataset_op_test.py"],
+ srcs = ["reduce_test.py"],
additional_deps = [
":test_base",
"@absl_py//absl/testing:parameterized",
@@ -437,7 +501,6 @@
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
@@ -445,9 +508,9 @@
)
tf_py_test(
- name = "sequence_dataset_op_test",
+ name = "repeat_test",
size = "small",
- srcs = ["sequence_dataset_op_test.py"],
+ srcs = ["repeat_test.py"],
additional_deps = [
":test_base",
"//third_party/py/numpy",
@@ -460,9 +523,9 @@
)
tf_py_test(
- name = "shard_dataset_op_test",
+ name = "shard_test",
size = "small",
- srcs = ["shard_dataset_op_test.py"],
+ srcs = ["shard_test.py"],
additional_deps = [
":test_base",
"//tensorflow/python:client_testlib",
@@ -472,9 +535,9 @@
)
tf_py_test(
- name = "shuffle_dataset_op_test",
+ name = "shuffle_test",
size = "small",
- srcs = ["shuffle_dataset_op_test.py"],
+ srcs = ["shuffle_test.py"],
additional_deps = [
":test_base",
"@absl_py//absl/testing:parameterized",
@@ -491,21 +554,91 @@
],
)
-py_library(
- name = "test_base",
- srcs = ["test_base.py"],
- deps = [
+tf_py_test(
+ name = "skip_test",
+ size = "small",
+ srcs = ["skip_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/ops:dataset_ops",
],
)
tf_py_test(
- name = "window_dataset_op_test",
+ name = "take_test",
size = "small",
- srcs = ["window_dataset_op_test.py"],
+ srcs = ["take_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+tf_py_test(
+ name = "text_line_dataset_test",
+ size = "small",
+ srcs = ["text_line_dataset_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:util",
+ ],
+)
+
+tf_py_test(
+ name = "tf_record_dataset_test",
+ size = "small",
+ srcs = ["tf_record_dataset_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
+ name = "test_base",
+ srcs = ["test_base.py"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+tf_py_test(
+ name = "window_test",
+ size = "medium",
+ srcs = ["window_test.py"],
additional_deps = [
":test_base",
"@absl_py//absl/testing:parameterized",
@@ -521,9 +654,9 @@
)
tf_py_test(
- name = "zip_dataset_op_test",
+ name = "zip_test",
size = "small",
- srcs = ["zip_dataset_op_test.py"],
+ srcs = ["zip_test.py"],
additional_deps = [
":test_base",
"//third_party/py/numpy",
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
deleted file mode 100644
index 10a0427..0000000
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ /dev/null
@@ -1,515 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @parameterized.named_parameters(
- ('even', 28, 14, False),
- ('uneven_with_remainder', 28, 15, False),
- ('uneven_without_remainder', 28, 15, True),
- ('empty', 0, 14, False),
- )
- def testBatchDataset(self, count, batch_size, drop_remainder):
- """Tests the batch dataset logic for various input configurations.
-
- Args:
- count: the number of input elements
- batch_size: the batch size
- drop_remainder: whether a smaller batch size should be produced if batch
- size does not divide number of inputs evenly
- """
-
- # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
- # RepeatDataset(count) -> BatchDataset(batch_size).
- components = (np.arange(7),
- np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
- np.array(37.0) * np.arange(7))
-
- count_t = array_ops.placeholder(dtypes.int64, shape=[])
- batch_size_t = array_ops.placeholder(dtypes.int64, shape=[])
- drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[])
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
- .repeat(count).batch(batch_size,
- drop_remainder).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- if drop_remainder:
- dim0 = batch_size
- else:
- dim0 = None
- self.assertEqual([[dim0] + list(c.shape[1:]) for c in components],
- [t.shape.as_list() for t in get_next])
-
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- count_t: count,
- batch_size_t: batch_size,
- drop_remainder_t: drop_remainder
- })
- num_full_batches = (count * 7) // batch_size
- for i in range(num_full_batches):
- result = self.evaluate(get_next)
- for component, result_component in zip(components, result):
- for j in range(batch_size):
- self.assertAllEqual(component[(i * batch_size + j) % 7]**2,
- result_component[j])
- if not drop_remainder and (count * 7) % batch_size > 0:
- result = self.evaluate(get_next)
- for component, result_component in zip(components, result):
- for j in range((count * 7) % batch_size):
- self.assertAllEqual(
- component[(num_full_batches * batch_size + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testBatchDatasetInvalidBatchSize(self):
- iterator = (dataset_ops.Dataset.range(10).batch(0).make_one_shot_iterator())
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- def testBatchSparse(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(
- 5).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for i in range(2):
- actual = self.evaluate(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
- values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
- dense_shape=[5, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testBatchSparseWithDifferentDenseShapes(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=array_ops.expand_dims(
- math_ops.range(i, dtype=dtypes.int64), 1),
- values=array_ops.fill([math_ops.to_int32(i)], i),
- dense_shape=[i])
-
- iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(
- 5).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for i in range(2):
- actual = self.evaluate(get_next)
- expected_indices = []
- expected_values = []
- for j in range(5):
- for k in range(i * 5 + j):
- expected_indices.append([j, k])
- expected_values.append(i * 5 + j)
- expected = sparse_tensor.SparseTensorValue(
- indices=expected_indices,
- values=expected_values,
- dense_shape=[5, (i + 1) * 5 - 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testNestedBatchSparse(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(5).batch(
- 2).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- actual = self.evaluate(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0],
- [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]],
- values=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
- dense_shape=[2, 5, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testBatchShapeError(self):
-
- def generator():
- yield [1.0, 2.0, 3.0]
- yield [4.0, 5.0, 6.0]
- yield [7.0, 8.0, 9.0, 10.0]
-
- iterator = (
- dataset_ops.Dataset.from_generator(
- generator, dtypes.float32, output_shapes=[None]).batch(3)
- .make_initializable_iterator())
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(iterator.initializer)
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- r'Cannot batch tensors with different shapes in component 0. '
- r'First element had shape \[3\] and element 2 had shape \[4\].'):
- sess.run(next_element)
-
-
-def _random_seq_lens(count):
- return np.random.randint(20, size=(count,)).astype(np.int32)
-
-
-class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @parameterized.named_parameters(
- ('default_padding', _random_seq_lens(32), 4, [-1], False),
- ('constant_padding', _random_seq_lens(32), 4, [25], False),
- ('uneven_with_remainder', _random_seq_lens(34), 4, [-1], False),
- ('uneven_without_remainder', _random_seq_lens(34), 4, [-1], True),
- )
- def testPaddedBatchDataset(self, seq_lens, batch_size, padded_shapes,
- drop_remainder):
- """Tests the padded batch dataset logic for various input configurations.
-
- Args:
- seq_lens: the input sequence lengths
- batch_size: the batch size
- padded_shapes: the padded shapes to use
- drop_remainder: whether a smaller batch size should be produced if batch
- size does not divide number of inputs evenly
- """
-
- seq_lens_t = array_ops.placeholder(dtypes.int32, shape=[None])
- batch_size_t = array_ops.placeholder(dtypes.int64, shape=[])
- padded_shapes_t = array_ops.placeholder(dtypes.int64, shape=[1])
- drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[])
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(seq_lens_t)
- .map(lambda x: array_ops.fill([x], x)).padded_batch(
- batch_size=batch_size_t,
- drop_remainder=drop_remainder_t,
- padded_shapes=padded_shapes_t).make_initializable_iterator())
-
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- seq_lens_t: seq_lens,
- batch_size_t: batch_size,
- padded_shapes_t: padded_shapes,
- drop_remainder_t: drop_remainder,
- })
-
- num_full_batches = len(seq_lens) // batch_size
-
- for i in range(num_full_batches):
- result = self.evaluate(get_next)
- padded_len = padded_shapes[0]
- if padded_len is None or padded_len == -1:
- padded_len = np.max(result) if result.size > 0 else 0
- self.assertEqual((batch_size, padded_len), result.shape)
- for j in range(batch_size):
- seq_len = seq_lens[(i * batch_size) + j]
- self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
- self.assertAllEqual(result[j, seq_len:],
- [0] * (padded_len - seq_len))
-
- if not drop_remainder and len(seq_lens) % batch_size > 0:
- result = self.evaluate(get_next)
- padded_len = np.max(result) if result.size > 0 else 0
- self.assertEqual((len(seq_lens) % batch_size, padded_len),
- result.shape)
- for j in range(len(seq_lens) % batch_size):
- seq_len = seq_lens[num_full_batches * batch_size + j]
- self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
- self.assertAllEqual(result[j, seq_len:],
- [0] * (padded_len - seq_len))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testPaddedBatchShortPadding(self):
- iterator = (
- dataset_ops.Dataset.from_tensor_slices([6, 5, 5, 5, 5])
- .map(lambda x: array_ops.fill([x], x)).padded_batch(
- batch_size=4, padded_shapes=[5]).make_one_shot_iterator())
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- with self.assertRaises(errors.DataLossError):
- sess.run(get_next)
-
- def testPaddedBatchEmptyTensors(self):
- iterator = (
- dataset_ops.Dataset.from_tensor_slices([0, 0, 0, 0])
- .map(lambda x: array_ops.fill([x], x)).padded_batch(
- batch_size=4, padded_shapes=[-1]).make_one_shot_iterator())
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- result = self.evaluate(get_next)
- self.assertAllEqual([[], [], [], []], result)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testPaddedBatchDatasetNonDefaultPadding(self):
- seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
- padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
-
- def fill_tuple(x):
- filled = array_ops.fill([x], x)
- return (filled, string_ops.as_string(filled))
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
- .padded_batch(
- 4,
- padded_shapes=(padded_shape, padded_shape),
- padding_values=(-1, '<end>')).make_initializable_iterator())
-
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # Test with random sequence lengths, and max padding.
- random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
- sess.run(
- init_op, feed_dict={
- padded_shape: [-1],
- seq_lens: random_seq_lens
- })
- for i in range(8):
- result = self.evaluate(get_next)
- padded_len = np.max(result[0])
- self.assertEqual((4, padded_len), result[0].shape)
- self.assertEqual((4, padded_len), result[1].shape)
- for j in range(4):
- seq_len = random_seq_lens[(i * 4) + j]
- self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
- self.assertAllEqual(result[0][j, seq_len:],
- [-1] * (padded_len - seq_len))
- self.assertAllEqual(result[1][j, :seq_len],
- [compat.as_bytes(str(seq_len))] * seq_len)
- self.assertAllEqual(result[1][j, seq_len:],
- [b'<end>'] * (padded_len - seq_len))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testPaddedBatchDatasetUnicode(self):
- # See GitHub issue 16149
- def generator():
- data = [[u'Простой', u'тест', u'юникода'],
- [u'никогда', u'не', u'бывает', u'простым']]
-
- for seq in data:
- yield seq, [0, 1, 2, 3]
-
- dataset = dataset_ops.Dataset.from_generator(
- generator, (dtypes.string, dtypes.int32),
- (tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None])))
- padded_dataset = dataset.padded_batch(
- 2, padded_shapes=([None], [None]), padding_values=('', 0))
- with self.cached_session() as sess:
- next_element = padded_dataset.make_one_shot_iterator().get_next()
- sess.run(next_element)
-
- def testPaddedBatchDatasetShapeSpecifications(self):
- int_placeholder = array_ops.placeholder(dtypes.int32)
- float_placeholder = array_ops.placeholder(dtypes.float32)
- string_placeholder = array_ops.placeholder(dtypes.string)
- input_dataset = dataset_ops.Dataset.from_tensors(
- (int_placeholder, float_placeholder, string_placeholder))
-
- # Test different ways of specifying the `padded_shapes` argument.
- dynamic_padding_from_tensor_shapes = input_dataset.padded_batch(
- 32,
- padded_shapes=(tensor_shape.TensorShape([None]),
- tensor_shape.TensorShape([None, None]),
- tensor_shape.TensorShape([37])))
- dynamic_padding_from_lists = input_dataset.padded_batch(
- 32, padded_shapes=([None], [None, None], [37]))
- dynamic_padding_from_lists_with_minus_one = input_dataset.padded_batch(
- 32, padded_shapes=([-1], [-1, -1], [37]))
- dynamic_padding_from_tensors = input_dataset.padded_batch(
- 32,
- padded_shapes=(constant_op.constant([-1], dtype=dtypes.int64),
- constant_op.constant([-1, -1], dtype=dtypes.int64),
- constant_op.constant([37], dtype=dtypes.int64)))
-
- for dataset in [
- dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists,
- dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors
- ]:
- self.assertEqual([None, None], dataset.output_shapes[0].as_list())
- self.assertEqual([None, None, None], dataset.output_shapes[1].as_list())
- self.assertEqual([None, 37], dataset.output_shapes[2].as_list())
-
- def testPaddedBatchSparseError(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
-
- with self.assertRaises(TypeError):
- _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10)
-
- def testPaddedBatchShapeError(self):
- with self.assertRaisesRegexp(
- ValueError, r'The padded shape \(1,\) is not compatible with the '
- r'corresponding input component shape \(\).'):
- _ = dataset_ops.Dataset.range(10).padded_batch(5, padded_shapes=[1])
-
- with self.assertRaisesRegexp(
- ValueError, r'The padded shape \(1,\) is not compatible with the '
- r'corresponding input component shape \(3,\).'):
- _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
- 5, padded_shapes=[1])
-
- with self.assertRaisesRegexp(
- ValueError, r'Padded shape .* must be a 1-D tensor '
- r'of tf.int64 values, but its shape was \(2, 2\).'):
- _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
- 5, padded_shapes=[[1, 1], [1, 1]])
-
- with self.assertRaisesRegexp(
- TypeError, r'Padded shape .* must be a 1-D tensor '
- r'of tf.int64 values, but its element type was float32.'):
- _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
- 5, padded_shapes=constant_op.constant([1., 2., 3.]))
-
- with self.assertRaisesRegexp(
- ValueError, r'The padded shape \(1,\) is not compatible with the '
- r'corresponding input component shape \(\).'):
- shape_as_tensor = constant_op.constant([1], dtype=dtypes.int64)
- _ = dataset_ops.Dataset.range(10).padded_batch(
- 5, padded_shapes=shape_as_tensor)
-
- with self.assertRaisesRegexp(
- ValueError,
- r'The padded shape \((\?|None), (\?|None)\) is not compatible with the '
- r'corresponding input component shape \(\).'):
- shape_as_tensor = array_ops.placeholder(dtypes.int64, shape=[2])
- _ = dataset_ops.Dataset.range(10).padded_batch(
- 5, padded_shapes=shape_as_tensor)
-
-
-class BatchDatasetBenchmark(test.Benchmark):
-
- def benchmarkBatchSparse(self):
- non_zeros_per_row_values = [0, 1, 5, 10, 100]
- batch_size_values = [1, 32, 64, 128, 1024]
-
- sparse_placeholder = array_ops.sparse_placeholder(dtype=dtypes.int64)
- batch_size_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[])
-
- dataset = dataset_ops.Dataset.from_tensors(sparse_placeholder).repeat(
- ).batch(batch_size_placeholder)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- for non_zeros_per_row in non_zeros_per_row_values:
-
- sparse_value = sparse_tensor.SparseTensorValue(
- indices=np.arange(non_zeros_per_row, dtype=np.int64)[:, np.newaxis],
- values=np.arange(non_zeros_per_row, dtype=np.int64),
- dense_shape=[1000])
-
- for batch_size in batch_size_values:
-
- with session.Session() as sess:
- sess.run(iterator.initializer, feed_dict={
- sparse_placeholder: sparse_value,
- batch_size_placeholder: batch_size})
- # Run five steps to warm up the session caches before taking the
- # first measurement.
- for _ in range(5):
- sess.run(next_element.indices.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.indices.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100.0
-
- print('Batch sparse dataset non-zeros per row: %d batch_size: %d '
- 'wall time: %f'
- % (non_zeros_per_row, batch_size, median_wall_time))
- self.report_benchmark(
- iters=10000, wall_time=median_wall_time,
- name='benchmark_batch_sparse_dataset_nnz_%d_batch_size_%d' % (
- non_zeros_per_row, batch_size))
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/batch_test.py b/tensorflow/python/data/kernel_tests/batch_test.py
new file mode 100644
index 0000000..5b035e5
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/batch_test.py
@@ -0,0 +1,173 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.batch()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class BatchTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('even', 28, 14, False),
+ ('uneven_with_remainder', 28, 15, False),
+ ('uneven_without_remainder', 28, 15, True),
+ ('empty', 0, 14, False),
+ )
+ def testBatchDataset(self, count, batch_size, drop_remainder):
+ """Tests the batch dataset logic for various input configurations.
+
+ Args:
+ count: the number of input elements
+ batch_size: the batch size
+ drop_remainder: whether a smaller batch size should be produced if batch
+ size does not divide number of inputs evenly
+ """
+
+ # The pipeline is TensorSliceDataset -> MapDataset(square_3) ->
+ # RepeatDataset(count) -> BatchDataset(batch_size).
+ components = (np.arange(7),
+ np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+ np.array(37.0) * np.arange(7))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
+ _map_fn).repeat(count).batch(batch_size, drop_remainder)
+ get_next = self.getNext(dataset)
+
+ if drop_remainder:
+ dim0 = batch_size
+ else:
+ dim0 = None
+ self.assertEqual(
+ [ts.as_list() for ts in nest.flatten(dataset.output_shapes)],
+ [[dim0] + list(c.shape[1:]) for c in components])
+
+ num_full_batches = (count * 7) // batch_size
+ for i in range(num_full_batches):
+ result = self.evaluate(get_next())
+ for component, result_component in zip(components, result):
+ for j in range(batch_size):
+ self.assertAllEqual(component[(i * batch_size + j) % 7]**2,
+ result_component[j])
+ if not drop_remainder and (count * 7) % batch_size > 0:
+ result = self.evaluate(get_next())
+ for component, result_component in zip(components, result):
+ for j in range((count * 7) % batch_size):
+ self.assertAllEqual(
+ component[(num_full_batches * batch_size + j) % 7]**2,
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ result = self.evaluate(get_next())
+
+ def testBatchDatasetInvalidBatchSize(self):
+ dataset = (dataset_ops.Dataset.range(10).batch(0))
+ self.assertDatasetProduces(
+ dataset, expected_error=(errors.InvalidArgumentError, ''))
+
+ def testBatchSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ dataset = dataset_ops.Dataset.range(10).map(_sparse).batch(5)
+ expected_output = [
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
+ values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
+ dense_shape=[5, 1]) for i in range(2)
+ ]
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ def testBatchSparseWithDifferentDenseShapes(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=array_ops.expand_dims(
+ math_ops.range(i, dtype=dtypes.int64), 1),
+ values=array_ops.fill([math_ops.to_int32(i)], i),
+ dense_shape=[i])
+
+ dataset = dataset_ops.Dataset.range(10).map(_sparse).batch(5)
+ expected_output = []
+ for i in range(2):
+ expected_indices = []
+ expected_outputs = []
+ for j in range(5):
+ for k in range(i * 5 + j):
+ expected_indices.append([j, k])
+ expected_outputs.append(i * 5 + j)
+ expected_output.append(
+ sparse_tensor.SparseTensorValue(
+ indices=expected_indices,
+ values=expected_outputs,
+ dense_shape=[5, (i + 1) * 5 - 1]))
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ def testNestedBatchSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ dataset = dataset_ops.Dataset.range(10).map(_sparse).batch(5).batch(2)
+ expected_output = [
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0],
+ [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]],
+ values=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+ dense_shape=[2, 5, 1])
+ ]
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ def testBatchShapeError(self):
+
+ def generator():
+ yield [1.0, 2.0, 3.0]
+ yield [4.0, 5.0, 6.0]
+ yield [7.0, 8.0, 9.0, 10.0]
+
+ dataset = (
+ dataset_ops.Dataset.from_generator(
+ generator, dtypes.float32, output_shapes=[None]).batch(3))
+ self.assertDatasetProduces(
+ dataset,
+ expected_error=(
+ errors.InvalidArgumentError,
+ r'Cannot batch tensors with different shapes in component 0. First '
+ r'element had shape \[3\] and element 2 had shape \[4\].'))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
deleted file mode 100644
index 1f35127..0000000
--- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
+++ /dev/null
@@ -1,318 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from os import path
-import shutil
-import tempfile
-
-import numpy as np
-
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-
-
-class FileCacheDatasetTest(test_base.DatasetTestBase):
-
- def setUp(self):
- self.tmp_dir = tempfile.mkdtemp()
- self.cache_prefix = path.join(self.tmp_dir, "cache")
-
- def tearDown(self):
- if self.tmp_dir:
- shutil.rmtree(self.tmp_dir, ignore_errors=True)
-
- def testCacheDatasetPassthrough(self):
- components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
- np.array([9.0, 10.0, 11.0, 12.0]))
- count_placeholder = array_ops.placeholder_with_default(
- constant_op.constant(5, dtypes.int64), shape=[])
- filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
-
- repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components)
- .repeat(count_placeholder))
-
- cache_dataset = repeat_dataset.cache(filename_placeholder)
-
- self.assertEqual(
- tuple([c.shape[1:] for c in components]), cache_dataset.output_shapes)
-
- # Create initialization ops for iterators without and with
- # caching, respectively.
- iterator = iterator_ops.Iterator.from_structure(cache_dataset.output_types,
- cache_dataset.output_shapes)
- init_fifo_op = iterator.make_initializer(repeat_dataset)
- init_cache_op = iterator.make_initializer(cache_dataset)
-
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # First run without caching to collect the "ground truth".
- self.evaluate(init_fifo_op)
- elements = []
- for _ in range(20):
- elements.append(sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Assert that the cached dataset has the same elements as the
- # "ground truth".
- sess.run(
- init_cache_op, feed_dict={filename_placeholder: self.cache_prefix})
- cached_elements = []
- for _ in range(20):
- cached_elements.append(sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- self.assertAllEqual(elements, cached_elements)
-
- # Re-initialize with an empty upstream (to throw errors.OutOfRangeError
- # if we didn't use the cache).
- sess.run(
- init_cache_op,
- feed_dict={
- count_placeholder: 0,
- filename_placeholder: self.cache_prefix
- })
- replayed_elements = []
- for _ in range(20):
- replayed_elements.append(sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- self.assertEqual(cached_elements, replayed_elements)
-
- # Re-initialize with an empty upstream and a missing cache file (should
- # throw errors.OutOfRangeError immediately).
- sess.run(
- init_cache_op,
- feed_dict={
- count_placeholder: 0,
- filename_placeholder: self.cache_prefix + "nonsense"
- })
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testConcurrentWriters(self):
- components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
- np.array([9.0, 10.0, 11.0, 12.0]))
- filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
-
- cache_dataset1 = (dataset_ops.Dataset.from_tensor_slices(components)
- .cache(filename_placeholder))
- cache_dataset2 = (dataset_ops.Dataset.from_tensor_slices(components)
- .cache(filename_placeholder))
-
- iterator1 = cache_dataset1.make_initializable_iterator()
- iterator2 = cache_dataset2.make_initializable_iterator()
- init_cache_op1 = iterator1.initializer
- init_cache_op2 = iterator2.initializer
-
- get_next1 = iterator1.get_next()
- get_next2 = iterator2.get_next()
-
- with self.cached_session() as sess:
- sess.run(
- init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
- sess.run(get_next1) # this should succeed
-
- sess.run(
- init_cache_op2, feed_dict={filename_placeholder: self.cache_prefix})
- with self.assertRaises(errors.AlreadyExistsError):
- sess.run(get_next2)
-
- sess.run(get_next1) # this should continue to succeed
-
- def testConcurrentReaders(self):
- components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
- np.array([9.0, 10.0, 11.0, 12.0]))
- filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
-
- cache_dataset1 = (dataset_ops.Dataset.from_tensor_slices(components)
- .cache(filename_placeholder))
- cache_dataset2 = (dataset_ops.Dataset.from_tensor_slices(components)
- .cache(filename_placeholder))
-
- iterator1 = cache_dataset1.make_initializable_iterator()
- iterator2 = cache_dataset2.make_initializable_iterator()
- init_cache_op1 = iterator1.initializer
- init_cache_op2 = iterator2.initializer
-
- get_next1 = iterator1.get_next()
- get_next2 = iterator2.get_next()
-
- with self.cached_session() as sess:
- sess.run(
- init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
- elements = []
- for _ in range(4):
- elements.append(sess.run(get_next1))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next1)
-
- # Re-initialize
- sess.run(
- init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix})
- sess.run(
- init_cache_op2, feed_dict={filename_placeholder: self.cache_prefix})
-
- # Reading concurrently should succeed.
- elements_itr1 = []
- elements_itr2 = []
- elements_itr2.append(sess.run(get_next2))
- elements_itr1.append(sess.run(get_next1))
- elements_itr2.append(sess.run(get_next2))
- elements_itr1.append(sess.run(get_next1))
- # Intentionally reversing the order
- elements_itr1.append(sess.run(get_next1))
- elements_itr2.append(sess.run(get_next2))
- elements_itr1.append(sess.run(get_next1))
- elements_itr2.append(sess.run(get_next2))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next2)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next1)
-
- self.assertAllEqual(elements, elements_itr1)
- self.assertAllEqual(elements, elements_itr2)
-
-
-class MemoryCacheDatasetTest(test_base.DatasetTestBase):
-
- def testCacheDatasetPassthrough(self):
- with ops.device("cpu:0"):
- repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64))
- dataset = dataset_ops.Dataset.range(3).flat_map(
- lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count))
-
- cached_dataset = dataset.cache().repeat(2)
- uncached_dataset = dataset.repeat(2)
-
- # Needs to be initializable to capture the variable.
- cached_iterator = cached_dataset.make_initializable_iterator()
- cached_next = cached_iterator.get_next()
- uncached_iterator = uncached_dataset.make_initializable_iterator()
- uncached_next = uncached_iterator.get_next()
-
- with self.cached_session() as sess:
-
- self.evaluate(repeat_count.initializer)
- self.evaluate(cached_iterator.initializer)
- self.evaluate(uncached_iterator.initializer)
-
- for i in range(3):
- for _ in range(10):
- self.assertEqual(self.evaluate(cached_next), i)
- self.assertEqual(self.evaluate(uncached_next), i)
-
- sess.run(repeat_count.assign(0))
-
- # The uncached iterator should now be empty.
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(uncached_next)
-
- # The cached iterator replays from cache.
- for i in range(3):
- for _ in range(10):
- self.assertEqual(self.evaluate(cached_next), i)
-
- # The cached iterator should now be empty.
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(cached_next)
-
- def testEmptyCacheReading(self):
- components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
- np.array([9.0, 10.0, 11.0, 12.0]))
- count_placeholder = array_ops.placeholder_with_default(
- constant_op.constant(5, dtypes.int64), shape=[])
-
- repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components)
- .repeat(count_placeholder))
-
- cache_dataset = repeat_dataset.cache()
-
- # Create initialization ops for iterators without and with
- # caching, respectively.
- iterator = cache_dataset.make_initializable_iterator()
- init_cache_op = iterator.initializer
-
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # Initialize with an empty upstream and a missing cache file (should
- # throw errors.OutOfRangeError immediately).
- sess.run(init_cache_op, feed_dict={count_placeholder: 0})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testConcurrentReaders(self):
- count_placeholder = array_ops.placeholder_with_default(
- constant_op.constant(5, dtypes.int64), shape=[])
- dataset = dataset_ops.Dataset.range(count_placeholder).cache()
- d1 = dataset.map(lambda x: x + 1)
- d2 = dataset.map(lambda x: x + 6)
-
- i1 = d1.make_initializable_iterator()
- i2 = d2.make_initializable_iterator()
-
- with self.cached_session() as sess:
- self.evaluate(i1.initializer)
-
- self.assertEqual(1, sess.run(i1.get_next()))
- self.assertEqual(2, sess.run(i1.get_next()))
- self.assertEqual(3, sess.run(i1.get_next()))
-
- sess.run(i2.initializer, feed_dict={count_placeholder: 3})
-
- self.assertEqual(6, sess.run(i2.get_next()))
- self.assertEqual(7, sess.run(i2.get_next()))
- self.assertEqual(4, sess.run(i1.get_next())) # interleave execution
- self.assertEqual([8, 5], sess.run([i2.get_next(), i1.get_next()]))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(i1.get_next())
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(i2.get_next())
-
- def testCacheTakeRepeat(self):
- dataset = dataset_ops.Dataset.range(10).cache().take(5).repeat(2)
- itr = dataset.make_one_shot_iterator()
- n = itr.get_next()
-
- expected_values = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
-
- with self.cached_session() as sess:
- for i, expected in enumerate(expected_values):
- self.assertEqual(expected, self.evaluate(n),
- "Unexpected value at index %s" % i)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(itr.get_next())
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/cache_test.py b/tensorflow/python/data/kernel_tests/cache_test.py
new file mode 100644
index 0000000..b561cd5
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/cache_test.py
@@ -0,0 +1,253 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.cache()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from os import path
+import shutil
+import tempfile
+
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class FileCacheTest(test_base.DatasetTestBase):
+
+ def setUp(self):
+ self.tmp_dir = tempfile.mkdtemp()
+ self.cache_prefix = path.join(self.tmp_dir, "cache")
+
+ def tearDown(self):
+ if self.tmp_dir:
+ shutil.rmtree(self.tmp_dir, ignore_errors=True)
+
+ def testCacheDatasetPassthrough(self):
+ components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
+ np.array([9.0, 10.0, 11.0, 12.0]))
+
+ def dataset_fn(count=5, filename=None):
+ repeat_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
+ if filename:
+ return repeat_dataset.cache(filename)
+ else:
+ return repeat_dataset
+
+ self.assertEqual(
+ tuple([c.shape[1:] for c in components]),
+ dataset_fn().output_shapes)
+
+ get_next = self.getNext(dataset_fn())
+
+ # First run without caching to collect the "ground truth".
+ elements = []
+ for _ in range(20):
+ elements.append(self.evaluate(get_next()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ # Assert that the cached dataset has the same elements as the
+ # "ground truth".
+ get_next = self.getNext(dataset_fn(filename=self.cache_prefix))
+ cached_elements = []
+ for _ in range(20):
+ cached_elements.append(self.evaluate(get_next()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+ self.assertAllEqual(elements, cached_elements)
+
+ # Re-initialize with an empty upstream (to throw errors.OutOfRangeError
+ # if we didn't use the cache).
+ get_next = self.getNext(dataset_fn(count=0, filename=self.cache_prefix))
+ replayed_elements = []
+ for _ in range(20):
+ replayed_elements.append(self.evaluate(get_next()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+ self.assertEqual(cached_elements, replayed_elements)
+
+ # Re-initialize with an empty upstream and a missing cache file (should
+ # throw errors.OutOfRangeError immediately).
+ get_next = self.getNext(
+ dataset_fn(count=0, filename=self.cache_prefix + "nonsense"))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ def testConcurrentWriters(self):
+ components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
+ np.array([9.0, 10.0, 11.0, 12.0]))
+
+ cache_dataset1 = (
+ dataset_ops.Dataset.from_tensor_slices(components).cache(
+ self.cache_prefix))
+ cache_dataset2 = (
+ dataset_ops.Dataset.from_tensor_slices(components).cache(
+ self.cache_prefix))
+
+ get_next1 = self.getNext(cache_dataset1)
+ get_next2 = self.getNext(cache_dataset2)
+
+ self.evaluate(get_next1()) # this should succeed
+
+ with self.assertRaises(errors.AlreadyExistsError):
+ self.evaluate(get_next2())
+
+ self.evaluate(get_next1()) # this should continue to succeed
+
+ def testConcurrentReaders(self):
+ components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
+ np.array([9.0, 10.0, 11.0, 12.0]))
+
+ cache_dataset1 = (
+ dataset_ops.Dataset.from_tensor_slices(components).cache(
+ self.cache_prefix))
+ cache_dataset2 = (
+ dataset_ops.Dataset.from_tensor_slices(components).cache(
+ self.cache_prefix))
+
+ get_next1 = self.getNext(cache_dataset1)
+ get_next2 = self.getNext(cache_dataset2)
+
+ elements = []
+ for _ in range(4):
+ elements.append(self.evaluate(get_next1()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next1())
+
+ # Re-initialize
+ get_next1 = self.getNext(cache_dataset1)
+ get_next2 = self.getNext(cache_dataset2)
+
+ # Reading concurrently should succeed.
+ elements_itr1 = []
+ elements_itr2 = []
+ elements_itr2.append(self.evaluate(get_next2()))
+ elements_itr1.append(self.evaluate(get_next1()))
+ elements_itr2.append(self.evaluate(get_next2()))
+ elements_itr1.append(self.evaluate(get_next1()))
+ # Intentionally reversing the order
+ elements_itr1.append(self.evaluate(get_next1()))
+ elements_itr2.append(self.evaluate(get_next2()))
+ elements_itr1.append(self.evaluate(get_next1()))
+ elements_itr2.append(self.evaluate(get_next2()))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next2())
+
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next1())
+
+ self.assertAllEqual(elements, elements_itr1)
+ self.assertAllEqual(elements, elements_itr2)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class MemoryCacheTest(test_base.DatasetTestBase):
+
+ def testCacheDatasetPassthrough(self):
+ with ops.device("cpu:0"):
+ repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64))
+ dataset = dataset_ops.Dataset.range(3).flat_map(
+ lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count))
+
+ cached_dataset = dataset.cache().repeat(2)
+ uncached_dataset = dataset.repeat(2)
+
+ self.evaluate(repeat_count.initializer)
+ # Needs to be initializable to capture the variable.
+ cached_next = self.getNext(cached_dataset, requires_initialization=True)
+ uncached_next = self.getNext(
+ uncached_dataset, requires_initialization=True)
+ for i in range(3):
+ for _ in range(10):
+ self.assertEqual(self.evaluate(cached_next()), i)
+ self.assertEqual(self.evaluate(uncached_next()), i)
+
+ self.evaluate(repeat_count.assign(0))
+
+ # The uncached iterator should now be empty.
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(uncached_next())
+
+ # The cached iterator replays from cache.
+ for i in range(3):
+ for _ in range(10):
+ self.assertEqual(self.evaluate(cached_next()), i)
+
+ # The cached iterator should now be empty.
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(cached_next())
+
+ def testEmptyCacheReading(self):
+ components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
+ np.array([9.0, 10.0, 11.0, 12.0]))
+
+ repeat_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components).repeat(0))
+ cache_dataset = repeat_dataset.cache()
+
+ # Create initialization ops for iterators without and with
+ # caching, respectively.
+ self.assertDatasetProduces(cache_dataset, expected_output=[])
+
+ def testConcurrentReaders(self):
+
+ dataset = dataset_ops.Dataset.range(5).cache()
+ d1 = dataset.map(lambda x: x + 1)
+ d2 = dataset.map(lambda x: x + 6)
+
+ get_next1 = self.getNext(d1)
+
+ self.assertEqual(1, self.evaluate(get_next1()))
+ self.assertEqual(2, self.evaluate(get_next1()))
+ self.assertEqual(3, self.evaluate(get_next1()))
+
+ get_next2 = self.getNext(d2)
+
+ self.assertEqual(6, self.evaluate(get_next2()))
+ self.assertEqual(7, self.evaluate(get_next2()))
+ self.assertEqual(4, self.evaluate(get_next1())) # interleave execution
+ self.assertEqual([8, 5],
+ [self.evaluate(get_next2()),
+ self.evaluate(get_next1())])
+ self.assertEqual(9, self.evaluate(get_next2()))
+ self.assertEqual(10, self.evaluate(get_next2()))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next2())
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next1())
+
+ def testCacheTakeRepeat(self):
+ dataset = dataset_ops.Dataset.range(10).cache().take(5).repeat(2)
+
+ expected_output = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/python/data/kernel_tests/concatenate_test.py
similarity index 74%
rename from tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
rename to tensorflow/python/data/kernel_tests/concatenate_test.py
index a0ef69f..5d8bfdc 100644
--- a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/concatenate_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.Dataset.concatenate()."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -24,10 +24,12 @@
from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
-class ConcatenateDatasetTest(test_base.DatasetTestBase):
+@test_util.run_all_in_graph_and_eager_modes
+class ConcatenateTest(test_base.DatasetTestBase):
def testConcatenateDataset(self):
input_components = (
@@ -46,23 +48,19 @@
self.assertEqual(concatenated.output_shapes, (tensor_shape.TensorShape(
[20]), tensor_shape.TensorShape([15]), tensor_shape.TensorShape([])))
- iterator = concatenated.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
+ get_next = self.getNext(concatenated)
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for i in range(9):
- result = self.evaluate(get_next)
- if i < 4:
- for component, result_component in zip(input_components, result):
- self.assertAllEqual(component[i], result_component)
- else:
- for component, result_component in zip(to_concatenate_components,
- result):
- self.assertAllEqual(component[i - 4], result_component)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ for i in range(9):
+ result = self.evaluate(get_next())
+ if i < 4:
+ for component, result_component in zip(input_components, result):
+ self.assertAllEqual(component[i], result_component)
+ else:
+ for component, result_component in zip(to_concatenate_components,
+ result):
+ self.assertAllEqual(component[i - 4], result_component)
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
def testConcatenateDatasetDifferentShape(self):
input_components = (
@@ -79,24 +77,18 @@
self.assertEqual(
[ts.as_list()
for ts in nest.flatten(concatenated.output_shapes)], [[20], [None]])
-
- iterator = concatenated.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for i in range(9):
- result = self.evaluate(get_next)
- if i < 4:
- for component, result_component in zip(input_components, result):
- self.assertAllEqual(component[i], result_component)
- else:
- for component, result_component in zip(to_concatenate_components,
- result):
- self.assertAllEqual(component[i - 4], result_component)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ get_next = self.getNext(concatenated)
+ for i in range(9):
+ result = self.evaluate(get_next())
+ if i < 4:
+ for component, result_component in zip(input_components, result):
+ self.assertAllEqual(component[i], result_component)
+ else:
+ for component, result_component in zip(to_concatenate_components,
+ result):
+ self.assertAllEqual(component[i - 4], result_component)
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
def testConcatenateDatasetDifferentStructure(self):
input_components = (
diff --git a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/kernel_tests/dataset_checkpoint_test.py
similarity index 72%
rename from tensorflow/python/data/kernel_tests/range_dataset_op_test.py
rename to tensorflow/python/data/kernel_tests/dataset_checkpoint_test.py
index fcb025c..cdaa4fd 100644
--- a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_checkpoint_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Test RangeDataset."""
+"""Checkpoint tests for `tf.data.Dataset`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -26,7 +26,6 @@
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
@@ -35,51 +34,7 @@
from tensorflow.python.platform import test
-@test_util.run_all_in_graph_and_eager_modes
-class RangeDatasetTest(test_base.DatasetTestBase):
-
- def testStop(self):
- dataset = dataset_ops.Dataset.range(5)
- self.assertDatasetProduces(dataset, expected_output=range(5))
-
- def testStartStop(self):
- start, stop = 2, 5
- dataset = dataset_ops.Dataset.range(start, stop)
- self.assertDatasetProduces(dataset, expected_output=range(2, 5))
-
- def testStartStopStep(self):
- start, stop, step = 2, 10, 2
- dataset = dataset_ops.Dataset.range(start, stop, step)
- self.assertDatasetProduces(dataset, expected_output=range(2, 10, 2))
-
- def testZeroStep(self):
- start, stop, step = 2, 10, 0
- dataset = dataset_ops.Dataset.range(start, stop, step)
- self.assertDatasetProduces(
- dataset, expected_error=(errors.InvalidArgumentError, ""))
-
- def testNegativeStep(self):
- start, stop, step = 2, 10, -1
- dataset = dataset_ops.Dataset.range(start, stop, step)
- self.assertDatasetProduces(dataset, expected_output=range(2, 10, -1))
-
- def testStopLessThanStart(self):
- start, stop = 10, 2
- dataset = dataset_ops.Dataset.range(start, stop)
- self.assertDatasetProduces(dataset, expected_output=range(10, 2))
-
- def testStopLessThanStartWithPositiveStep(self):
- start, stop, step = 10, 2, 2
- dataset = dataset_ops.Dataset.range(start, stop, step)
- self.assertDatasetProduces(dataset, expected_output=range(10, 2, 2))
-
- def testStopLessThanStartWithNegativeStep(self):
- start, stop, step = 10, 2, -1
- dataset = dataset_ops.Dataset.range(start, stop, step)
- self.assertDatasetProduces(dataset, expected_output=range(10, 2, -1))
-
-
-class ExperimentalCheckpointDatasetTest(test_base.DatasetTestBase):
+class DatasetCheckpointTest(test_base.DatasetTestBase):
def tearDown(self):
# Remove all checkpoint files.
@@ -124,19 +79,19 @@
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
with self.session(graph=g) as sess:
- self.evaluate(variables.global_variables_initializer())
- self.evaluate(init_op)
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
for i in range(start, break_point):
- self.assertEqual(i, self.evaluate(get_next))
- self.evaluate(save_op)
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
- self.evaluate(init_op)
- self.evaluate(restore_op)
+ sess.run(init_op)
+ sess.run(restore_op)
for i in range(break_point, stop):
- self.assertEqual(i, self.evaluate(get_next))
+ self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -144,14 +99,14 @@
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
- self.evaluate(variables.global_variables_initializer())
- self.evaluate(init_op)
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
for i in range(start, break_point):
- self.assertEqual(i, self.evaluate(get_next))
- self.evaluate(save_op)
- self.evaluate(restore_op)
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+ sess.run(restore_op)
for i in range(break_point, stop):
- self.assertEqual(i, self.evaluate(get_next))
+ self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -175,14 +130,14 @@
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
with self.session(graph=g) as sess:
- self.evaluate(variables.global_variables_initializer())
- self.evaluate(init_op)
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
for _ in range(break_epoch):
for i in range(start, stop):
- self.assertEqual(i, self.evaluate(get_next))
+ self.assertEqual(i, sess.run(get_next))
for i in range(start, break_point):
- self.assertEqual(i, self.evaluate(get_next))
- self.evaluate(save_op)
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
with ops.Graph().as_default() as g:
# Create an empty IteratorResource and restore the Iterator into it.
@@ -193,12 +148,12 @@
restore_op = self._restore_op(iterator._iterator_resource)
get_next = iterator.get_next()
with self.session(graph=g) as sess:
- self.evaluate(restore_op)
+ sess.run(restore_op)
for i in range(break_point, stop):
- self.assertEqual(i, self.evaluate(get_next))
+ self.assertEqual(i, sess.run(get_next))
for _ in range(break_epoch + 1, num_epochs):
for i in range(start, stop):
- self.assertEqual(i, self.evaluate(get_next))
+ self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -221,20 +176,20 @@
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
with self.session(graph=g) as sess:
- self.evaluate(variables.global_variables_initializer())
- self.evaluate(init_op)
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
for i in range(start, break_point):
- self.assertEqual(i, self.evaluate(get_next))
- self.evaluate(save_op)
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
with ops.Graph().as_default() as g:
# Intentionally build a graph with a different value for stop to make sure
# the original dataset graph is actually getting loaded.
init_op, get_next, _, restore_op = _build_graph(start, stop_1)
with self.session(graph=g) as sess:
- self.evaluate(restore_op)
+ sess.run(restore_op)
for i in range(break_point, stop):
- self.assertEqual(i, self.evaluate(get_next))
+ self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -259,19 +214,19 @@
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
with self.session(graph=g) as sess:
- self.evaluate(variables.global_variables_initializer())
- self.evaluate(init_op)
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
for i in range(start, break_point):
- self.assertEqual(i, self.evaluate(get_next))
- self.evaluate(save_op)
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
- self.evaluate(init_op)
- self.evaluate(restore_op)
+ sess.run(init_op)
+ sess.run(restore_op)
for i in range(break_point, stop):
- self.assertEqual(i, self.evaluate(get_next))
+ self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -294,27 +249,27 @@
with ops.Graph().as_default() as g:
init_op, get_next, save_op, _ = _build_graph(start, stop)
with self.session(graph=g) as sess:
- self.evaluate(variables.global_variables_initializer())
- self.evaluate(init_op)
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
for i in range(start, break_point1):
- self.assertEqual(i, self.evaluate(get_next))
- self.evaluate(save_op)
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
- self.evaluate(restore_op)
+ sess.run(restore_op)
for i in range(break_point1, break_point2):
- self.assertEqual(i, self.evaluate(get_next))
- self.evaluate(save_op)
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
break_point2 = 7
with ops.Graph().as_default() as g:
init_op, get_next, save_op, restore_op = _build_graph(start, stop)
with self.session(graph=g) as sess:
- self.evaluate(restore_op)
+ sess.run(restore_op)
for i in range(break_point2, stop):
- self.assertEqual(i, self.evaluate(get_next))
+ self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -338,28 +293,28 @@
init_op, get_next, save_op, restore_op = _build_graph(
start, stop, num_epochs)
with self.session(graph=g) as sess:
- self.evaluate(variables.global_variables_initializer())
- self.evaluate(init_op)
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
with self.assertRaises(errors.NotFoundError):
- self.evaluate(restore_op)
+ sess.run(restore_op)
for _ in range(break_epoch - 1):
for i in range(start, stop):
- self.assertEqual(i, self.evaluate(get_next))
+ self.assertEqual(i, sess.run(get_next))
for i in range(start, break_range):
- self.assertEqual(i, self.evaluate(get_next))
- self.evaluate(save_op)
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
with self.session(graph=g) as sess:
- self.evaluate(restore_op)
+ sess.run(restore_op)
for i in range(break_range, stop):
- self.assertEqual(i, self.evaluate(get_next))
+ self.assertEqual(i, sess.run(get_next))
for _ in range(break_epoch, num_epochs):
for i in range(start, stop):
- self.assertEqual(i, self.evaluate(get_next))
+ self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -381,23 +336,23 @@
init_op, get_next, save_op, restore_op = _build_graph(
start, stop, num_epochs)
with self.session(graph=g) as sess:
- self.evaluate(variables.global_variables_initializer())
- self.evaluate(init_op)
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
# Note: There is no checkpoint saved currently so a NotFoundError is
# raised.
with self.assertRaises(errors.NotFoundError):
- self.evaluate(restore_op)
+ sess.run(restore_op)
for _ in range(num_epochs):
for i in range(start, stop):
- self.assertEqual(i, self.evaluate(get_next))
+ self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- self.evaluate(save_op)
+ sess.run(save_op)
with ops.Graph().as_default() as g:
init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
with self.session(graph=g) as sess:
- self.evaluate(restore_op)
+ sess.run(restore_op)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
diff --git a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
deleted file mode 100644
index f7b5008..0000000
--- a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
+++ /dev/null
@@ -1,650 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.platform import test
-
-
-class DatasetConstructorTest(test_base.DatasetTestBase):
-
- def testFromTensors(self):
- """Test a dataset that represents a single tuple of tensors."""
- components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
-
- iterator = (dataset_ops.Dataset.from_tensors(components)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([c.shape for c in components],
- [t.shape for t in get_next])
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- results = self.evaluate(get_next)
- for component, result_component in zip(components, results):
- self.assertAllEqual(component, result_component)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testFromTensorsSparse(self):
- """Test a dataset that represents a single tuple of tensors."""
- components = (sparse_tensor.SparseTensorValue(
- indices=np.array([[0]]),
- values=np.array([0]),
- dense_shape=np.array([1])),
- sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0], [1, 1]]),
- values=np.array([-1, 1]),
- dense_shape=np.array([2, 2])))
-
- iterator = (
- dataset_ops.Dataset.from_tensors(components)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual(
- [tensor_shape.TensorShape(c.dense_shape) for c in components],
- [shape for shape in iterator.output_shapes])
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- results = self.evaluate(get_next)
- for component, result_component in zip(components, results):
- self.assertSparseValuesEqual(component, result_component)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testFromTensorsMixed(self):
- """Test an dataset that represents a single tuple of tensors."""
- components = (np.array(1), np.array([1, 2, 3]), np.array(37.0),
- sparse_tensor.SparseTensorValue(
- indices=np.array([[0]]),
- values=np.array([0]),
- dense_shape=np.array([1])),
- sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0], [1, 1]]),
- values=np.array([-1, 1]),
- dense_shape=np.array([2, 2])))
-
- iterator = (
- dataset_ops.Dataset.from_tensors(components)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([
- tensor_shape.TensorShape(c.dense_shape)
- if sparse_tensor.is_sparse(c) else c.shape for c in components
- ], [shape for shape in iterator.output_shapes])
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- results = self.evaluate(get_next)
- for component, result_component in zip(components, results):
- if sparse_tensor.is_sparse(component):
- self.assertSparseValuesEqual(component, result_component)
- else:
- self.assertAllEqual(component, result_component)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testFromTensorSlices(self):
- """Test a dataset that represents the slices from a tuple of tensors."""
- components = (
- np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
- np.array([[12], [13], [14], [15]]), 22),
- np.array([37.0, 38.0, 39.0, 40.0])
- )
-
- iterator = (dataset_ops.Dataset.from_tensor_slices(components)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([c.shape[1:] for c in components],
- [t.shape for t in get_next])
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for i in range(4):
- results = self.evaluate(get_next)
- for component, result_component in zip(components, results):
- self.assertAllEqual(component[i], result_component)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testFromTensorSlicesSparse(self):
- """Test a dataset that represents the slices from a tuple of tensors."""
- components = (sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0], [1, 0], [2, 0]]),
- values=np.array([0, 0, 0]),
- dense_shape=np.array([3, 1])),
- sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0], [1, 1], [2, 2]]),
- values=np.array([1, 2, 3]),
- dense_shape=np.array([3, 3])))
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual(
- [tensor_shape.TensorShape(c.dense_shape[1:]) for c in components],
- [shape for shape in iterator.output_shapes])
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- expected = [
- (sparse_tensor.SparseTensorValue(
- indices=np.array([[0]]),
- values=np.array([0]),
- dense_shape=np.array([1])),
- sparse_tensor.SparseTensorValue(
- indices=np.array([[0]]),
- values=np.array([1]),
- dense_shape=np.array([3]))),
- (sparse_tensor.SparseTensorValue(
- indices=np.array([[0]]),
- values=np.array([0]),
- dense_shape=np.array([1])),
- sparse_tensor.SparseTensorValue(
- indices=np.array([[1]]),
- values=np.array([2]),
- dense_shape=np.array([3]))),
- (sparse_tensor.SparseTensorValue(
- indices=np.array([[0]]),
- values=np.array([0]),
- dense_shape=np.array([1])),
- sparse_tensor.SparseTensorValue(
- indices=np.array([[2]]),
- values=np.array([3]),
- dense_shape=np.array([3]))),
- ]
- for i in range(3):
- results = self.evaluate(get_next)
- for component, result_component in zip(expected[i], results):
- self.assertSparseValuesEqual(component, result_component)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testFromTensorSlicesMixed(self):
- """Test a dataset that represents the slices from a tuple of tensors."""
- components = (np.tile(np.array([[1], [2], [3]]), 20),
- np.tile(np.array([[12], [13], [14]]), 22),
- np.array([37.0, 38.0, 39.0]),
- sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0], [1, 0], [2, 0]]),
- values=np.array([0, 0, 0]),
- dense_shape=np.array([3, 1])),
- sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0], [1, 1], [2, 2]]),
- values=np.array([1, 2, 3]),
- dense_shape=np.array([3, 3])))
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([
- tensor_shape.TensorShape(c.dense_shape[1:])
- if sparse_tensor.is_sparse(c) else c.shape[1:] for c in components
- ], [shape for shape in iterator.output_shapes])
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- expected = [
- (sparse_tensor.SparseTensorValue(
- indices=np.array([[0]]),
- values=np.array([0]),
- dense_shape=np.array([1])),
- sparse_tensor.SparseTensorValue(
- indices=np.array([[0]]),
- values=np.array([1]),
- dense_shape=np.array([3]))),
- (sparse_tensor.SparseTensorValue(
- indices=np.array([[0]]),
- values=np.array([0]),
- dense_shape=np.array([1])),
- sparse_tensor.SparseTensorValue(
- indices=np.array([[1]]),
- values=np.array([2]),
- dense_shape=np.array([3]))),
- (sparse_tensor.SparseTensorValue(
- indices=np.array([[0]]),
- values=np.array([0]),
- dense_shape=np.array([1])),
- sparse_tensor.SparseTensorValue(
- indices=np.array([[2]]),
- values=np.array([3]),
- dense_shape=np.array([3]))),
- ]
- for i in range(3):
- results = self.evaluate(get_next)
- for component, result_component in zip(
- (list(zip(*components[:3]))[i] + expected[i]), results):
- if sparse_tensor.is_sparse(component):
- self.assertSparseValuesEqual(component, result_component)
- else:
- self.assertAllEqual(component, result_component)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testFromTensorSlicesWithDict(self):
- components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
- iterator = (dataset_ops.Dataset.from_tensor_slices(components)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual(dtypes.int32, iterator.output_types["foo"])
- self.assertEqual(dtypes.float32, iterator.output_types["bar"])
- self.assertEqual((), iterator.output_shapes["foo"])
- self.assertEqual((1,), iterator.output_shapes["bar"])
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for i in range(3):
- results = self.evaluate(get_next)
- self.assertEqual(components["foo"][i], results["foo"])
- self.assertEqual(components["bar"][i], results["bar"])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testFromSparseTensorSlices(self):
- """Test a dataset based on slices of a `tf.SparseTensor`."""
- st = array_ops.sparse_placeholder(dtypes.float64)
- iterator = (dataset_ops.Dataset.from_sparse_tensor_slices(st)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = sparse_tensor.SparseTensor(*iterator.get_next())
-
- with self.cached_session() as sess:
- slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
-
- # Test with sparse tensor in the appropriate order.
- indices = np.array(
- [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))])
- values = np.array([val for s in slices for val in s])
- dense_shape = np.array([len(slices), max(len(s) for s in slices) + 1])
- sparse_feed = sparse_tensor.SparseTensorValue(indices, values,
- dense_shape)
- sess.run(init_op, feed_dict={st: sparse_feed})
- for i, s in enumerate(slices):
- results = self.evaluate(get_next)
- self.assertAllEqual(s, results.values)
- expected_indices = np.array(
- [[j] for j in range(len(slices[i]))]).reshape([-1, 1])
- self.assertAllEqual(expected_indices, results.indices)
- self.assertAllEqual(dense_shape[1:], results.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test with sparse tensor in the reverse order, which is not
- # currently supported.
- reverse_order_indices = indices[::-1, :]
- reverse_order_values = values[::-1]
- sparse_feed = sparse_tensor.SparseTensorValue(
- reverse_order_indices, reverse_order_values, dense_shape)
- with self.assertRaises(errors.UnimplementedError):
- sess.run(init_op, feed_dict={st: sparse_feed})
-
- # Test with an empty sparse tensor.
- empty_indices = np.empty((0, 4), dtype=np.int64)
- empty_values = np.empty((0,), dtype=np.float64)
- empty_dense_shape = [0, 4, 37, 9]
- sparse_feed = sparse_tensor.SparseTensorValue(empty_indices, empty_values,
- empty_dense_shape)
- sess.run(init_op, feed_dict={st: sparse_feed})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # pylint: disable=g-long-lambda,unnecessary-lambda
- def testNestedStructure(self):
- components = (np.array([1, 2, 3], dtype=np.int64),
- (np.array([4., 5.]), np.array([6., 7.])),
- np.array([8, 9, 10], dtype=np.int64))
-
- dataset = dataset_ops.Dataset.from_tensors(components)
- self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
- dtypes.int64), dataset.output_types)
- self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
-
- dataset = dataset.shuffle(10, 10)
- self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
- dtypes.int64), dataset.output_types)
- self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
-
- dataset = dataset.repeat(-1)
- self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
- dtypes.int64), dataset.output_types)
- self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
-
- dataset = dataset.filter(lambda x, y, z: True)
- self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
- dtypes.int64), dataset.output_types)
- self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
-
- dataset = dataset.take(5)
- self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
- dtypes.int64), dataset.output_types)
- self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
-
- dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1])))
- self.assertEquals(((dtypes.int64, dtypes.int64),
- (dtypes.float64, dtypes.float64)), dataset.output_types)
- self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes)
-
- dataset = dataset.flat_map(
- lambda x, y: dataset_ops.Dataset.from_tensors(((x[0], x[1]),
- (y[0], y[1])))
- )
- self.assertEquals(((dtypes.int64, dtypes.int64),
- (dtypes.float64, dtypes.float64)), dataset.output_types)
- self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes)
-
- dataset = dataset.batch(32)
- self.assertEquals(((dtypes.int64, dtypes.int64),
- (dtypes.float64, dtypes.float64)), dataset.output_types)
- self.assertEquals((([None, 3], [None, 3]), ([None, 2], [None, 2])),
- nest.pack_sequence_as(dataset.output_shapes, [
- s.as_list()
- for s in nest.flatten(dataset.output_shapes)
- ]))
-
- iterator = dataset.make_one_shot_iterator()
- (w, x), (y, z) = iterator.get_next()
- self.assertEquals(dtypes.int64, w.dtype)
- self.assertEquals(dtypes.int64, x.dtype)
- self.assertEquals(dtypes.float64, y.dtype)
- self.assertEquals(dtypes.float64, z.dtype)
- self.assertEquals([None, 3], w.shape.as_list())
- self.assertEquals([None, 3], x.shape.as_list())
- self.assertEquals([None, 2], y.shape.as_list())
- self.assertEquals([None, 2], z.shape.as_list())
-
- iterator = dataset.make_initializable_iterator()
- (w, x), (y, z) = iterator.get_next()
- self.assertEquals(dtypes.int64, w.dtype)
- self.assertEquals(dtypes.int64, x.dtype)
- self.assertEquals(dtypes.float64, y.dtype)
- self.assertEquals(dtypes.float64, z.dtype)
- self.assertEquals([None, 3], w.shape.as_list())
- self.assertEquals([None, 3], x.shape.as_list())
- self.assertEquals([None, 2], y.shape.as_list())
- self.assertEquals([None, 2], z.shape.as_list())
-
- # Define a separate set of components with matching leading
- # dimension for the from-slices constructor.
- components_for_slices = (np.array([1, 2, 3], dtype=np.int64),
- (np.array([4., 5., 6.]),
- np.array([7., 8., 9.])),
- np.array([10, 11, 12], dtype=np.int64))
-
- dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices)
- self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
- dtypes.int64), dataset.output_types)
- self.assertEquals(([], ([], []), []), dataset.output_shapes)
-
- def testNestedDict(self):
- components = {"a": {"aa": 1, "ab": [2.0, 2.0]}, "b": [3, 3, 3]}
- dataset = dataset_ops.Dataset.from_tensors(components)
- self.assertEquals(dtypes.int32, dataset.output_types["a"]["aa"])
- self.assertEquals(dtypes.float32, dataset.output_types["a"]["ab"])
- self.assertEquals(dtypes.int32, dataset.output_types["b"])
- self.assertEquals([], dataset.output_shapes["a"]["aa"])
- self.assertEquals([2], dataset.output_shapes["a"]["ab"])
- self.assertEquals([3], dataset.output_shapes["b"])
-
- def testNonSequenceNestedStructure(self):
- components = np.array([1, 2, 3], dtype=np.int64)
-
- dataset = dataset_ops.Dataset.from_tensors(components)
- self.assertEquals(dtypes.int64, dataset.output_types)
- self.assertEquals([3], dataset.output_shapes)
-
- dataset = dataset.filter(
- lambda x: math_ops.reduce_all(math_ops.equal(x, components)))
- self.assertEquals(dtypes.int64, dataset.output_types)
- self.assertEquals([3], dataset.output_shapes)
-
- dataset = dataset.map(lambda x: array_ops.stack([x, x]))
- self.assertEquals(dtypes.int64, dataset.output_types)
- self.assertEquals([2, 3], dataset.output_shapes)
-
- dataset = dataset.flat_map(
- lambda x: dataset_ops.Dataset.from_tensor_slices(x))
- self.assertEquals(dtypes.int64, dataset.output_types)
- self.assertEquals([3], dataset.output_shapes)
-
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- self.assertEquals(dtypes.int64, get_next.dtype)
- self.assertEquals([3], get_next.shape)
-
- def testSplitPipelineFailsWithPlacementError(self):
- with session.Session(
- target="",
- config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
-
- dataset = dataset_ops.Dataset.from_tensors(0)
-
- # Define a pipeline that attempts to use variables on two
- # different devices.
- #
- # Initialize the variables before creating to iterator, to avoid the
- # placement algorithm overriding the DT_RESOURCE colocation constraints.
- with ops.device("/cpu:0"):
- var_0 = resource_variable_ops.ResourceVariable(initial_value=0)
- dataset = dataset.map(lambda x: x + var_0.read_value())
- self.evaluate(var_0.initializer)
-
- with ops.device("/cpu:1"):
- var_1 = resource_variable_ops.ResourceVariable(initial_value=0)
- dataset = dataset.map(lambda x: x + var_1.read_value())
- self.evaluate(var_1.initializer)
-
- iterator = dataset.make_initializable_iterator()
- self.evaluate(iterator.initializer)
-
- with self.assertRaisesRegexp(
- errors.FailedPreconditionError,
- "Error while reading resource variable Variable"):
- sess.run(iterator.get_next())
-
-
-class DatasetConstructorBenchmark(test.Benchmark):
-
- def benchmarkSliceRepeatBatch(self):
- input_size = 10000
- batch_size = 100
- num_epochs = 100
-
- input_data = np.random.randn(input_size)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(input_data)
- .repeat(num_epochs + 1).batch(batch_size))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- self.evaluate(iterator.initializer)
- # Run one whole epoch to burn in the computation.
- for _ in range(input_size // batch_size):
- sess.run(next_element)
- deltas = []
- try:
- while True:
- start = time.time()
- sess.run(next_element)
- deltas.append(time.time() - start)
- except errors.OutOfRangeError:
- pass
-
- median_wall_time = np.median(deltas)
- print("Slice/repeat/batch with sess.run() input size: %d batch size: %d "
- "Median wall time per element: %f" % (input_size, batch_size,
- median_wall_time))
- self.report_benchmark(
- iters=len(deltas),
- wall_time=median_wall_time,
- name="benchmark_slice_repeat_batch_input_%d_batch_%d" % (input_size,
- batch_size))
-
- def benchmarkSliceRepeatBatchCallable(self):
- input_size = 10000
- batch_size = 100
- num_epochs = 100
-
- input_data = np.random.randn(input_size)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(input_data)
- .repeat(num_epochs + 1).batch(batch_size))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- self.evaluate(iterator.initializer)
- get_next_element = sess.make_callable(next_element)
- # Run one whole epoch to burn in the computation.
- for _ in range(input_size // batch_size):
- get_next_element()
- deltas = []
- try:
- while True:
- start = time.time()
- get_next_element()
- deltas.append(time.time() - start)
- except errors.OutOfRangeError:
- pass
-
- median_wall_time = np.median(deltas)
- print(
- "Slice/repeat/batch with callable input size: %d batch size: %d Median"
- " wall time per element: %f" % (input_size, batch_size,
- median_wall_time))
- self.report_benchmark(
- iters=len(deltas),
- wall_time=median_wall_time,
- name="benchmark_slice_repeat_batch_callable_input_%d_batch_%d" %
- (input_size, batch_size))
-
- def benchmarkReshapeSliceRepeatCallable(self):
- input_size = 10000
- batch_size = 100
- num_epochs = 100
-
- input_data = np.random.randn(input_size)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(input_data.reshape(100, 100))
- .repeat(num_epochs + 1))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- self.evaluate(iterator.initializer)
- get_next_element = sess.make_callable(next_element)
- # Run one whole epoch to burn in the computation.
- for _ in range(input_size // batch_size):
- get_next_element()
- deltas = []
- try:
- while True:
- start = time.time()
- get_next_element()
- deltas.append(time.time() - start)
- except errors.OutOfRangeError:
- pass
-
- median_wall_time = np.median(deltas)
- print("Reshape/slice/repeat with callable input size: %d batch size: %d "
- "Median wall time per element: %f" % (input_size, batch_size,
- median_wall_time))
- self.report_benchmark(
- iters=len(deltas),
- wall_time=median_wall_time,
- name="benchmark_reshape_slice_repeat_callable_input_%d_batch_%d" %
- (input_size, batch_size))
-
- def benchmarkSliceBatchCacheRepeatCallable(self):
- input_size = 10000
- batch_size = 100
- num_epochs = 100
-
- input_data = np.random.randn(input_size)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(input_data).batch(batch_size)
- .cache().repeat(num_epochs + 1))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- self.evaluate(iterator.initializer)
- get_next_element = sess.make_callable(next_element)
- # Run one whole epoch to burn in the computation.
- for _ in range(input_size // batch_size):
- get_next_element()
- deltas = []
- try:
- while True:
- start = time.time()
- get_next_element()
- deltas.append(time.time() - start)
- except errors.OutOfRangeError:
- pass
-
- median_wall_time = np.median(deltas)
- print(
- "Slice/batch/cache/repeat with callable input size: %d batch size: %d "
- "Median wall time per element: %f"
- % (input_size, batch_size, median_wall_time))
- self.report_benchmark(
- iters=len(deltas),
- wall_time=median_wall_time,
- name="benchmark_slice_batch_cache_repeat_callable_input_%d_batch_%d" %
- (input_size, batch_size))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_test.py
similarity index 71%
rename from tensorflow/python/data/kernel_tests/dataset_ops_test.py
rename to tensorflow/python/data/kernel_tests/dataset_test.py
index a5324af..7dbab60 100644
--- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the input pipeline ops."""
+"""Tests for `tf.data.Dataset`."""
from __future__ import absolute_import
from __future__ import division
@@ -24,21 +24,26 @@
from tensorflow.core.framework import graph_pb2
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import structure
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
-class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
+@test_util.run_all_in_graph_and_eager_modes
+class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def testAsSerializedGraph(self):
dataset = dataset_ops.Dataset.range(10)
- with self.cached_session() as sess:
- graph = graph_pb2.GraphDef().FromString(
- sess.run(dataset._as_serialized_graph()))
- self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))
+ graph = graph_pb2.GraphDef().FromString(
+ self.evaluate(dataset._as_serialized_graph()))
+ self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))
@staticmethod
def make_apply_fn(dataset):
@@ -76,7 +81,7 @@
lambda: readers.FixedLengthRecordDataset("", 42)),
("FromGenerator",
lambda: dataset_ops.Dataset.from_generator(
- DatasetOpsTest.make_gen(), dtypes.int32),
+ DatasetTest.make_gen(), dtypes.int32),
1),
("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensors([42])),
@@ -235,7 +240,7 @@
options2 = dataset_ops.Options()
options2.experimental_autotune = False
with self.assertRaisesRegexp(ValueError,
- "Cannot merge incompatible values of option"):
+ "Cannot merge incompatible values"):
dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)
def testOptionsMergeOptionsFromMultipleInputs(self):
@@ -249,6 +254,64 @@
self.assertTrue(ds.options().experimental_autotune)
self.assertTrue(ds.options().experimental_filter_fusion)
+ # TODO(b/119882922): use-after-free bug in eager mode.
+ # pylint: disable=g-long-lambda
+ @parameterized.named_parameters(
+ ("Tensor", lambda: constant_op.constant(37.0),
+ structure.TensorStructure(dtypes.float32, [])),
+ ("SparseTensor", lambda: sparse_tensor.SparseTensor(
+ indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
+ dense_shape=[1]),
+ structure.SparseTensorStructure(dtypes.int32, [1])),
+ ("Nest", lambda: {
+ "a": constant_op.constant(37.0),
+ "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
+ structure.NestedStructure({
+ "a": structure.TensorStructure(dtypes.float32, []),
+ "b": (structure.TensorStructure(dtypes.string, [1]),
+ structure.TensorStructure(dtypes.string, []))})),
+ ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
+ constant_op.constant([1, 2, 3])),
+ dataset_ops.DatasetStructure(
+ structure.TensorStructure(dtypes.int32, []))),
+ ("Optional", lambda: optional_ops.Optional.from_value(37.0),
+ optional_ops.OptionalStructure(
+ structure.TensorStructure(dtypes.float32, []))),
+ )
+ def testSkipEagerDatasetStructure(self, tf_value_fn,
+ expected_element_structure):
+ dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
+ dataset_structure = structure.Structure.from_value(dataset)
+ self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)
+
+ # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
+ # the element structure.
+ self.assertTrue(expected_element_structure.is_compatible_with(
+ dataset_structure._element_structure))
+ self.assertTrue(dataset_structure._element_structure.is_compatible_with(
+ expected_element_structure))
+
+ self.assertEqual([dtypes.variant], dataset_structure._flat_types)
+ self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)
+
+ # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
+ # and _to_tensor_list().
+ round_trip_dataset = dataset_structure._from_tensor_list(
+ dataset_structure._to_tensor_list(dataset))
+
+ value = tf_value_fn()
+
+ if isinstance(value, dataset_ops.Dataset):
+ self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
+ elif isinstance(value, optional_ops.Optional):
+ self.assertDatasetProduces(
+ round_trip_dataset.map(lambda opt: opt.get_value()),
+ [self.evaluate(value.get_value())],
+ requires_initialization=True)
+ else:
+ self.assertDatasetProduces(
+ round_trip_dataset, [self.evaluate(tf_value_fn())],
+ requires_initialization=True)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
deleted file mode 100644
index 5ddb222..0000000
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ /dev/null
@@ -1,220 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class FilterDatasetTest(test_base.DatasetTestBase):
-
- def testFilterDataset(self):
- components = (
- np.arange(7, dtype=np.int64),
- np.array([[1, 2, 3]], dtype=np.int64) * np.arange(
- 7, dtype=np.int64)[:, np.newaxis],
- np.array(37.0, dtype=np.float64) * np.arange(7)
- )
- count = array_ops.placeholder(dtypes.int64, shape=[])
- modulus = array_ops.placeholder(dtypes.int64)
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
- .repeat(count)
- .filter(lambda x, _y, _z: math_ops.equal(math_ops.mod(x, modulus), 0))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([c.shape[1:] for c in components],
- [t.shape for t in get_next])
-
- with self.cached_session() as sess:
- # Test that we can dynamically feed a different modulus value for each
- # iterator.
- def do_test(count_val, modulus_val):
- sess.run(init_op, feed_dict={count: count_val, modulus: modulus_val})
- for _ in range(count_val):
- for i in [x for x in range(7) if x**2 % modulus_val == 0]:
- result = self.evaluate(get_next)
- for component, result_component in zip(components, result):
- self.assertAllEqual(component[i]**2, result_component)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- do_test(14, 2)
- do_test(4, 18)
-
- # Test an empty dataset.
- do_test(0, 1)
-
- def testFilterRange(self):
- dataset = dataset_ops.Dataset.range(100).filter(
- lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertEqual(0, self.evaluate(get_next))
- self.assertEqual(1, self.evaluate(get_next))
- self.assertEqual(3, self.evaluate(get_next))
-
- def testFilterDict(self):
- iterator = (dataset_ops.Dataset.range(10)
- .map(lambda x: {"foo": x * 2, "bar": x ** 2})
- .filter(lambda d: math_ops.equal(d["bar"] % 2, 0))
- .map(lambda d: d["foo"] + d["bar"])
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for i in range(10):
- if (i ** 2) % 2 == 0:
- self.assertEqual(i * 2 + i**2, self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testUseStepContainerInFilter(self):
- input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
-
- # Define a predicate that returns true for the first element of
- # the sequence and not the second, and uses `tf.map_fn()`.
- def _predicate(xs):
- squared_xs = functional_ops.map_fn(lambda x: x * x, xs)
- summed = math_ops.reduce_sum(squared_xs)
- return math_ops.equal(summed, 1 + 4 + 9)
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]])
- .filter(_predicate)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- self.assertAllEqual(input_data[0], self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testSparse(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0]]),
- values=(i * np.array([1])),
- dense_shape=np.array([1, 1])), i
-
- def _filter_fn(_, i):
- return math_ops.equal(i % 2, 0)
-
- iterator = (
- dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map(
- lambda x, i: x).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for i in range(5):
- actual = self.evaluate(get_next)
- self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
- self.assertSparseValuesEqual(actual, _map_fn(i * 2)[0])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testShortCircuit(self):
- iterator = (
- dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.range(10),
- dataset_ops.Dataset.from_tensors(True).repeat(None)))
- .filter(lambda x, y: y).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for i in range(10):
- self.assertEqual((i, True), self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testParallelFilters(self):
- dataset = dataset_ops.Dataset.range(10).filter(
- lambda x: math_ops.equal(x % 2, 0))
- iterators = [dataset.make_one_shot_iterator() for _ in range(10)]
- next_elements = [iterator.get_next() for iterator in iterators]
- with self.cached_session() as sess:
- self.assertEqual([0 for _ in range(10)], self.evaluate(next_elements))
-
-
-class FilterDatasetBenchmark(test.Benchmark):
-
- def _benchmark(self, predicate, name):
- with ops.Graph().as_default():
- dataset = (
- dataset_ops.Dataset.from_tensors(True).repeat(None).filter(predicate))
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(5):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- print("Filter dataset using %s. Median wall time: %f" %
- (name, median_wall_time))
- self.report_benchmark(
- iters=100,
- wall_time=median_wall_time,
- name="benchmark_filter_dataset_%s" % name)
-
- def benchmarkSimpleFunction(self):
- self._benchmark(array_ops.identity, "simple_function")
-
- def benchmarkReturnComponentOptimization(self):
- self._benchmark(lambda x: x, "return_component")
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/filter_test.py b/tensorflow/python/data/kernel_tests/filter_test.py
new file mode 100644
index 0000000..afaf954
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/filter_test.py
@@ -0,0 +1,128 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.filter()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class FilterTest(test_base.DatasetTestBase):
+
+ def testFilterDataset(self):
+ components = (
+ np.arange(7, dtype=np.int64),
+ np.array([[1, 2, 3]], dtype=np.int64) * np.arange(
+ 7, dtype=np.int64)[:, np.newaxis],
+ np.array(37.0, dtype=np.float64) * np.arange(7)
+ )
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ def do_test(count, modulus):
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
+ _map_fn).repeat(count).filter(
+ lambda x, _y, _z: math_ops.equal(math_ops.mod(x, modulus), 0))
+ self.assertEqual([c.shape[1:] for c in components],
+ [shape for shape in dataset.output_shapes])
+ get_next = self.getNext(dataset)
+ for _ in range(count):
+ for i in [x for x in range(7) if x**2 % modulus == 0]:
+ result = self.evaluate(get_next())
+ for component, result_component in zip(components, result):
+ self.assertAllEqual(component[i]**2, result_component)
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ do_test(14, 2)
+ do_test(4, 18)
+
+ # Test an empty dataset.
+ do_test(0, 1)
+
+ def testFilterRange(self):
+ dataset = dataset_ops.Dataset.range(4).filter(
+ lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2))
+ self.assertDatasetProduces(dataset, expected_output=[0, 1, 3])
+
+ def testFilterDict(self):
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda x: {"foo": x * 2, "bar": x ** 2}).filter(
+ lambda d: math_ops.equal(d["bar"] % 2, 0)).map(
+ lambda d: d["foo"] + d["bar"])
+ self.assertDatasetProduces(
+ dataset,
+ expected_output=[(i * 2 + i**2) for i in range(10) if not (i**2) % 2])
+
+ def testUseStepContainerInFilter(self):
+ input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
+
+ # Define a predicate that returns true for the first element of
+ # the sequence and not the second, and uses `tf.map_fn()`.
+ def _predicate(xs):
+ squared_xs = functional_ops.map_fn(lambda x: x * x, xs)
+ summed = math_ops.reduce_sum(squared_xs)
+ return math_ops.equal(summed, 1 + 4 + 9)
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(
+ [[1, 2, 3], [4, 5, 6]]).filter(_predicate)
+ self.assertDatasetProduces(dataset, expected_output=[input_data[0]])
+
+ def testSparse(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1])), i
+
+ def _filter_fn(_, i):
+ return math_ops.equal(i % 2, 0)
+
+ dataset = dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map(
+ lambda x, i: x)
+ self.assertDatasetProduces(
+ dataset, expected_output=[_map_fn(i * 2)[0] for i in range(5)])
+
+ def testShortCircuit(self):
+ dataset = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.range(10),
+ dataset_ops.Dataset.from_tensors(True).repeat(None)
+ )).filter(lambda x, y: y)
+ self.assertDatasetProduces(
+ dataset, expected_output=[(i, True) for i in range(10)])
+
+ def testParallelFilters(self):
+ dataset = dataset_ops.Dataset.range(10).filter(
+ lambda x: math_ops.equal(x % 2, 0))
+ next_elements = [self.getNext(dataset) for _ in range(10)]
+ self.assertEqual([0 for _ in range(10)],
+ self.evaluate(
+ [next_element() for next_element in next_elements]))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/fixed_length_record_dataset_test.py b/tensorflow/python/data/kernel_tests/fixed_length_record_dataset_test.py
new file mode 100644
index 0000000..9503e57
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/fixed_length_record_dataset_test.py
@@ -0,0 +1,171 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.FixedLengthRecordDataset`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import zlib
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import readers
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class FixedLengthRecordDatasetTest(test_base.DatasetTestBase):
+
+ def setUp(self):
+ super(FixedLengthRecordDatasetTest, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+ self._header_bytes = 5
+ self._record_bytes = 3
+ self._footer_bytes = 2
+
+ def _record(self, f, r):
+ return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
+
+ def _createFiles(self, compression_type=None):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
+ filenames.append(fn)
+
+ contents = []
+ contents.append(b"H" * self._header_bytes)
+ for j in range(self._num_records):
+ contents.append(self._record(i, j))
+ contents.append(b"F" * self._footer_bytes)
+ contents = b"".join(contents)
+
+ if not compression_type:
+ with open(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "GZIP":
+ with gzip.GzipFile(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "ZLIB":
+ contents = zlib.compress(contents)
+ with open(fn, "wb") as f:
+ f.write(contents)
+ else:
+ raise ValueError("Unsupported compression_type", compression_type)
+
+ return filenames
+
+ def _testFixedLengthRecordDataset(self, compression_type=None):
+ test_filenames = self._createFiles(compression_type=compression_type)
+
+ def dataset_fn(filenames, num_epochs, batch_size=None):
+ repeat_dataset = readers.FixedLengthRecordDataset(
+ filenames,
+ self._record_bytes,
+ self._header_bytes,
+ self._footer_bytes,
+ compression_type=compression_type).repeat(num_epochs)
+ if batch_size:
+ return repeat_dataset.batch(batch_size)
+ return repeat_dataset
+
+ # Basic test: read from file 0.
+ self.assertDatasetProduces(
+ dataset_fn([test_filenames[0]], 1),
+ expected_output=[
+ self._record(0, i) for i in range(self._num_records)
+ ])
+
+ # Basic test: read from file 1.
+ self.assertDatasetProduces(
+ dataset_fn([test_filenames[1]], 1),
+ expected_output=[
+ self._record(1, i) for i in range(self._num_records)
+ ])
+
+ # Basic test: read from both files.
+ expected_output = []
+ for j in range(self._num_files):
+ expected_output.extend(
+ [self._record(j, i) for i in range(self._num_records)])
+ self.assertDatasetProduces(
+ dataset_fn(test_filenames, 1), expected_output=expected_output)
+
+ # Test repeated iteration through both files.
+ get_next = self.getNext(dataset_fn(test_filenames, 10))
+ for _ in range(10):
+ for j in range(self._num_files):
+ for i in range(self._num_records):
+ self.assertEqual(self._record(j, i), self.evaluate(get_next()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ # Test batched and repeated iteration through both files.
+ get_next = self.getNext(dataset_fn(test_filenames, 10, self._num_records))
+ for _ in range(10):
+ for j in range(self._num_files):
+ self.assertAllEqual(
+ [self._record(j, i) for i in range(self._num_records)],
+ self.evaluate(get_next()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ def testFixedLengthRecordDatasetNoCompression(self):
+ self._testFixedLengthRecordDataset()
+
+ def testFixedLengthRecordDatasetGzipCompression(self):
+ self._testFixedLengthRecordDataset(compression_type="GZIP")
+
+ def testFixedLengthRecordDatasetZlibCompression(self):
+ self._testFixedLengthRecordDataset(compression_type="ZLIB")
+
+ def testFixedLengthRecordDatasetBuffering(self):
+ test_filenames = self._createFiles()
+ dataset = readers.FixedLengthRecordDataset(
+ test_filenames,
+ self._record_bytes,
+ self._header_bytes,
+ self._footer_bytes,
+ buffer_size=10)
+ expected_output = []
+ for j in range(self._num_files):
+ expected_output.extend(
+ [self._record(j, i) for i in range(self._num_records)])
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ def testFixedLengthRecordDatasetWrongSize(self):
+ test_filenames = self._createFiles()
+ dataset = readers.FixedLengthRecordDataset(
+ test_filenames,
+ self._record_bytes + 1, # Incorrect record length.
+ self._header_bytes,
+ self._footer_bytes,
+ buffer_size=10)
+ self.assertDatasetProduces(
+ dataset,
+ expected_error=(
+ errors.InvalidArgumentError,
+ r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input "
+ r"file \".*fixed_length_record.0.txt\" has body length 21 bytes, "
+ r"which is not an exact multiple of the record length \(4 bytes\).")
+ )
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
deleted file mode 100644
index 02979fc..0000000
--- a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
+++ /dev/null
@@ -1,152 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import random
-
-import numpy as np
-
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.platform import test
-from tensorflow.python.training import server_lib
-
-
-class FlatMapDatasetTest(test_base.DatasetTestBase):
-
- # pylint: disable=g-long-lambda
- def testFlatMapDataset(self):
- repeats = [1, 2, 3, 4, 5, 0, 1]
- components = np.array(repeats, dtype=np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .flat_map(lambda x: dataset_ops.Dataset.from_tensors([x]).repeat(x))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for i in repeats:
- for _ in range(i):
- self.assertEqual(i, self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testNestedFlatMapDataset(self):
- repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
- components = np.array(repeats, dtype=np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x)
- .flat_map(lambda y: dataset_ops.Dataset.from_tensors(y)
- .repeat(y))).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for row in repeats:
- for i in row:
- for _ in range(i):
- self.assertEqual(i, self.evaluate(get_next))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testSharedResourceNestedFlatMapDataset(self):
- repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
- components = np.array(repeats, dtype=np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x)
- .flat_map(lambda y: dataset_ops.Dataset.from_tensors(y)
- .repeat(y))).make_initializable_iterator(
- shared_name="shared_flat_map_iterator"))
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- # Create two concurrent sessions that share the same iterator
- # resource on the same server, and verify that a random
- # interleaving of `Session.run(get_next)` calls on the two
- # sessions yields the expected result.
- server = server_lib.Server.create_local_server()
- with session.Session(server.target) as sess1:
- with session.Session(server.target) as sess2:
- for _ in range(3):
- sess = random.choice([sess1, sess2])
- self.evaluate(init_op)
- for row in repeats:
- for i in row:
- for _ in range(i):
- sess = random.choice([sess1, sess2])
- self.assertEqual(i, self.evaluate(get_next))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess = random.choice([sess1, sess2])
- sess.run(get_next)
-
- def testMapDict(self):
- iterator = (dataset_ops.Dataset.range(10)
- .map(lambda x: {"foo": x * 2, "bar": x ** 2})
- .flat_map(lambda d: dataset_ops.Dataset.from_tensors(d["foo"])
- .repeat(d["bar"]))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for i in range(10):
- for _ in range(i ** 2):
- self.assertEqual(i * 2, self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- # pylint: enable=g-long-lambda
-
- def testSparse(self):
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
-
- def _flat_map_fn(x):
- return dataset_ops.Dataset.from_tensor_slices(
- sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
-
- iterator = (
- dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn)
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- for i in range(10):
- for j in range(2):
- expected = [i, 0] if j % 2 == 0 else [0, -i]
- self.assertAllEqual(expected, self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/flat_map_test.py b/tensorflow/python/data/kernel_tests/flat_map_test.py
new file mode 100644
index 0000000..5f11c2e
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/flat_map_test.py
@@ -0,0 +1,125 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.flat_map()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import server_lib
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class FlatMapTest(test_base.DatasetTestBase):
+
+ # pylint: disable=g-long-lambda
+ def testFlatMapDataset(self):
+ repeats = [1, 2, 3, 4, 5, 0, 1]
+ components = np.array(repeats, dtype=np.int64)
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).flat_map(
+ lambda x: dataset_ops.Dataset.from_tensors([x]).repeat(x))
+ expected_output = []
+ for i in repeats:
+ expected_output.extend([[i]] * i)
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ def testNestedFlatMapDataset(self):
+ repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
+ components = np.array(repeats, dtype=np.int64)
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).flat_map(
+ lambda x: dataset_ops.Dataset.from_tensor_slices(x).flat_map(
+ lambda y: dataset_ops.Dataset.from_tensors(y).repeat(y))
+ )
+ expected_output = []
+ for row in repeats:
+ for i in row:
+ expected_output.extend([i] * i)
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ # Note: no eager mode coverage, session specific test.
+ def testSkipEagerSharedResourceNestedFlatMapDataset(self):
+ repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
+ components = np.array(repeats, dtype=np.int64)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x)
+ .flat_map(lambda y: dataset_ops.Dataset.from_tensors(y)
+ .repeat(y))).make_initializable_iterator(
+ shared_name="shared_flat_map_iterator"))
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ # Create two concurrent sessions that share the same iterator
+ # resource on the same server, and verify that a random
+ # interleaving of `Session.run(get_next)` calls on the two
+ # sessions yields the expected result.
+ server = server_lib.Server.create_local_server()
+ with session.Session(server.target) as sess1:
+ with session.Session(server.target) as sess2:
+ for _ in range(3):
+ sess = random.choice([sess1, sess2])
+ sess.run(init_op)
+ for row in repeats:
+ for i in row:
+ for _ in range(i):
+ sess = random.choice([sess1, sess2])
+ self.assertEqual(i, sess.run(get_next))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess = random.choice([sess1, sess2])
+ sess.run(get_next)
+
+ def testMapDict(self):
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda x: {"foo": x * 2, "bar": x ** 2}).flat_map(
+ lambda d: dataset_ops.Dataset.from_tensors(
+ d["foo"]).repeat(d["bar"]))
+ get_next = self.getNext(dataset)
+ for i in range(10):
+ for _ in range(i**2):
+ self.assertEqual(i * 2, self.evaluate(get_next()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ def testSparse(self):
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _flat_map_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ dataset = dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn)
+ expected_output = []
+ for i in range(10):
+ for j in range(2):
+ expected_output.append([i, 0] if j % 2 == 0 else [0, -i])
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/from_generator_test.py
similarity index 88%
rename from tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
rename to tensorflow/python/data/kernel_tests/from_generator_test.py
index 7087b4d..4d82c21 100644
--- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
+++ b/tensorflow/python/data/kernel_tests/from_generator_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for tf.data.Dataset.from_generator()."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -31,7 +31,7 @@
from tensorflow.python.platform import test
-class DatasetConstructorTest(test_base.DatasetTestBase):
+class FromGeneratorTest(test_base.DatasetTestBase):
def _testFromGenerator(self, generator, elem_sequence, num_repeats,
output_types=None):
@@ -47,10 +47,10 @@
with self.cached_session() as sess:
for _ in range(2): # Run twice to test reinitialization.
- self.evaluate(init_op)
+ sess.run(init_op)
for _ in range(num_repeats):
for elem in elem_sequence:
- self.assertAllEqual(elem, self.evaluate(get_next))
+ self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -65,7 +65,7 @@
with self.cached_session() as sess:
for _ in range(num_repeats):
for elem in elem_sequence:
- self.assertAllEqual(elem, self.evaluate(get_next))
+ self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -133,10 +133,10 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for _ in range(num_inner_repeats * num_outer_repeats):
for elem in input_list:
- val0, val1 = self.evaluate(get_next)
+ val0, val1 = sess.run(get_next)
self.assertAllEqual(elem[0], val0)
self.assertAllEqual(elem[1], val1)
with self.assertRaises(errors.OutOfRangeError):
@@ -192,10 +192,10 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for elem in [0, 1]:
for _ in range(num_parallel_iterators):
- self.assertAllEqual(elem, self.evaluate(get_next))
+ self.assertAllEqual(elem, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -215,9 +215,9 @@
self.assertEqual(dtype, get_next.dtype)
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for expected in [[1], [2], [3]]:
- next_val = self.evaluate(get_next)
+ next_val = sess.run(get_next)
self.assertEqual(dtype.as_numpy_dtype, next_val.dtype)
self.assertAllEqual(expected, next_val)
with self.assertRaises(errors.OutOfRangeError):
@@ -236,9 +236,9 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for expected in [b"foo", b"bar", b"baz"]:
- next_val = self.evaluate(get_next)
+ next_val = sess.run(get_next)
self.assertAllEqual(expected, next_val)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -257,12 +257,12 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
- self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
- self.assertAllEqual([4, 5, 6], self.evaluate(get_next))
+ sess.run(init_op)
+ self.assertAllEqual([1, 2, 3], sess.run(get_next))
+ self.assertAllEqual([4, 5, 6], sess.run(get_next))
with self.assertRaisesOpError("The expected type was int64"):
sess.run(get_next)
- self.assertAllEqual([7, 8, 9], self.evaluate(get_next))
+ self.assertAllEqual([7, 8, 9], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -280,12 +280,12 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
- self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
- self.assertAllEqual([4, 5, 6], self.evaluate(get_next))
+ sess.run(init_op)
+ self.assertAllEqual([1, 2, 3], sess.run(get_next))
+ self.assertAllEqual([4, 5, 6], sess.run(get_next))
with self.assertRaisesOpError(r"element of shape \(3,\) was expected"):
sess.run(get_next)
- self.assertAllEqual([11, 12, 13], self.evaluate(get_next))
+ self.assertAllEqual([11, 12, 13], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -304,16 +304,16 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
- self.assertEqual((1, 2), self.evaluate(get_next))
- self.assertEqual((3, 4), self.evaluate(get_next))
+ sess.run(init_op)
+ self.assertEqual((1, 2), sess.run(get_next))
+ self.assertEqual((3, 4), sess.run(get_next))
with self.assertRaisesOpError(
r"The expected structure was \(tf\.int64, tf\.int64\)"):
sess.run(get_next)
with self.assertRaisesOpError(
r"The expected structure was \(tf\.int64, tf\.int64\)"):
sess.run(get_next)
- self.assertEqual((9, 10), self.evaluate(get_next))
+ self.assertEqual((9, 10), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -329,9 +329,9 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
- self.assertAllEqual(1, self.evaluate(get_next))
- self.assertAllEqual([2, 3], self.evaluate(get_next))
+ sess.run(init_op)
+ self.assertAllEqual(1, sess.run(get_next))
+ self.assertAllEqual([2, 3], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -349,9 +349,9 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
- self.assertAllEqual(0, self.evaluate(get_next))
- self.assertAllEqual(1, self.evaluate(get_next))
+ sess.run(init_op)
+ self.assertAllEqual(0, sess.run(get_next))
+ self.assertAllEqual(1, sess.run(get_next))
def testFromGeneratorDestructorCalled(self):
# Use an `Event` to signal that the generator has been deleted.
@@ -378,9 +378,9 @@
get_next = iterator.get_next()
with session.Session() as sess:
- self.evaluate(init_op)
- self.assertAllEqual(42, self.evaluate(get_next))
- self.assertAllEqual(42, self.evaluate(get_next))
+ sess.run(init_op)
+ self.assertAllEqual(42, sess.run(get_next))
+ self.assertAllEqual(42, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `GeneratorWrapper` object is destroyed when the
@@ -407,10 +407,10 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
expected = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
for x in expected:
- self.assertEqual(x, self.evaluate(get_next))
+ self.assertEqual(x, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -436,13 +436,13 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
expected = [(0, b"Hi!"),
(0, b"Hi!"), (1, b"Hi!"),
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"),
(0, b"Hi!"), (1, b"Hi!"), (2, b"Hi!"), (3, b"Hi!")]
for x in expected:
- self.assertEqual(x, self.evaluate(get_next))
+ self.assertEqual(x, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -470,9 +470,9 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
- self.assertAllEqual(37, self.evaluate(get_next))
- self.assertAllEqual(37, self.evaluate(get_next))
+ sess.run(init_op)
+ self.assertAllEqual(37, sess.run(get_next))
+ self.assertAllEqual(37, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertTrue(event.is_set())
diff --git a/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py b/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py
new file mode 100644
index 0000000..d23ac0e
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py
@@ -0,0 +1,85 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.from_sparse_tensor_slices()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class FromSparseTensorSlicesTest(test_base.DatasetTestBase):
+
+ def testSkipEagerFromSparseTensorSlices(self):
+ """Test a dataset based on slices of a `tf.SparseTensor`."""
+ st = array_ops.sparse_placeholder(dtypes.float64)
+ iterator = (dataset_ops.Dataset.from_sparse_tensor_slices(st)
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+
+ with self.cached_session() as sess:
+ slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
+
+ # Test with sparse tensor in the appropriate order.
+ indices = np.array(
+ [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))])
+ values = np.array([val for s in slices for val in s])
+ dense_shape = np.array([len(slices), max(len(s) for s in slices) + 1])
+ sparse_feed = sparse_tensor.SparseTensorValue(indices, values,
+ dense_shape)
+ sess.run(init_op, feed_dict={st: sparse_feed})
+ for i, s in enumerate(slices):
+ results = sess.run(get_next)
+ self.assertAllEqual(s, results.values)
+ expected_indices = np.array(
+ [[j] for j in range(len(slices[i]))]).reshape([-1, 1])
+ self.assertAllEqual(expected_indices, results.indices)
+ self.assertAllEqual(dense_shape[1:], results.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test with sparse tensor in the reverse order, which is not
+ # currently supported.
+ reverse_order_indices = indices[::-1, :]
+ reverse_order_values = values[::-1]
+ sparse_feed = sparse_tensor.SparseTensorValue(
+ reverse_order_indices, reverse_order_values, dense_shape)
+ with self.assertRaises(errors.UnimplementedError):
+ sess.run(init_op, feed_dict={st: sparse_feed})
+
+ # Test with an empty sparse tensor.
+ empty_indices = np.empty((0, 4), dtype=np.int64)
+ empty_values = np.empty((0,), dtype=np.float64)
+ empty_dense_shape = [0, 4, 37, 9]
+ sparse_feed = sparse_tensor.SparseTensorValue(empty_indices, empty_values,
+ empty_dense_shape)
+ sess.run(init_op, feed_dict={st: sparse_feed})
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/from_tensor_slices_test.py b/tensorflow/python/data/kernel_tests/from_tensor_slices_test.py
new file mode 100644
index 0000000..9a480e5
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/from_tensor_slices_test.py
@@ -0,0 +1,177 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.from_tensor_slices()."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class FromTensorSlicesTest(test_base.DatasetTestBase):
+
+ def testFromTensorSlices(self):
+ """Test a dataset that represents the slices from a tuple of tensors."""
+ components = (
+ np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
+ np.array([[12], [13], [14], [15]]), 22),
+ np.array([37.0, 38.0, 39.0, 40.0])
+ )
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components)
+ get_next = self.getNext(dataset)
+
+ self.assertEqual([c.shape[1:] for c in components],
+ [shape for shape in dataset.output_shapes])
+
+ for i in range(4):
+ results = self.evaluate(get_next())
+ for component, result_component in zip(components, results):
+ self.assertAllEqual(component[i], result_component)
+ with self.assertRaises(errors.OutOfRangeError):
+ results = self.evaluate(get_next())
+
+ def testSkipEagerFromTensorSlicesSparse(self):
+ """Test a dataset that represents the slices from a tuple of tensors."""
+ components = (sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0], [1, 0], [2, 0]]),
+ values=np.array([0, 0, 0]),
+ dense_shape=np.array([3, 1])),
+ sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0], [1, 1], [2, 2]]),
+ values=np.array([1, 2, 3]),
+ dense_shape=np.array([3, 3])))
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components)
+
+ self.assertEqual(
+ [tensor_shape.TensorShape(c.dense_shape[1:]) for c in components],
+ [shape for shape in dataset.output_shapes])
+
+ expected = [
+ (sparse_tensor.SparseTensorValue(
+ indices=np.array([[0]]),
+ values=np.array([0]),
+ dense_shape=np.array([1])),
+ sparse_tensor.SparseTensorValue(
+ indices=np.array([[0]]),
+ values=np.array([1]),
+ dense_shape=np.array([3]))),
+ (sparse_tensor.SparseTensorValue(
+ indices=np.array([[0]]),
+ values=np.array([0]),
+ dense_shape=np.array([1])),
+ sparse_tensor.SparseTensorValue(
+ indices=np.array([[1]]),
+ values=np.array([2]),
+ dense_shape=np.array([3]))),
+ (sparse_tensor.SparseTensorValue(
+ indices=np.array([[0]]),
+ values=np.array([0]),
+ dense_shape=np.array([1])),
+ sparse_tensor.SparseTensorValue(
+ indices=np.array([[2]]),
+ values=np.array([3]),
+ dense_shape=np.array([3]))),
+ ]
+ self.assertDatasetProduces(dataset, expected_output=expected)
+
+ def testFromTensorSlicesMixed(self):
+ """Test a dataset that represents the slices from a tuple of tensors."""
+ components = (np.tile(np.array([[1], [2], [3]]), 20),
+ np.tile(np.array([[12], [13], [14]]), 22),
+ np.array([37.0, 38.0, 39.0]),
+ sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0], [1, 0], [2, 0]]),
+ values=np.array([0, 0, 0]),
+ dense_shape=np.array([3, 1])),
+ sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0], [1, 1], [2, 2]]),
+ values=np.array([1, 2, 3]),
+ dense_shape=np.array([3, 3])))
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components)
+ get_next = self.getNext(dataset)
+ self.assertEqual([
+ tensor_shape.TensorShape(c.dense_shape[1:])
+ if sparse_tensor.is_sparse(c) else c.shape[1:] for c in components
+ ], [shape for shape in dataset.output_shapes])
+
+ expected = [
+ (sparse_tensor.SparseTensorValue(
+ indices=np.array([[0]]),
+ values=np.array([0]),
+ dense_shape=np.array([1])),
+ sparse_tensor.SparseTensorValue(
+ indices=np.array([[0]]),
+ values=np.array([1]),
+ dense_shape=np.array([3]))),
+ (sparse_tensor.SparseTensorValue(
+ indices=np.array([[0]]),
+ values=np.array([0]),
+ dense_shape=np.array([1])),
+ sparse_tensor.SparseTensorValue(
+ indices=np.array([[1]]),
+ values=np.array([2]),
+ dense_shape=np.array([3]))),
+ (sparse_tensor.SparseTensorValue(
+ indices=np.array([[0]]),
+ values=np.array([0]),
+ dense_shape=np.array([1])),
+ sparse_tensor.SparseTensorValue(
+ indices=np.array([[2]]),
+ values=np.array([3]),
+ dense_shape=np.array([3]))),
+ ]
+ for i in range(3):
+ results = self.evaluate(get_next())
+ for component, result_component in zip(
+ (list(zip(*components[:3]))[i] + expected[i]), results):
+ if sparse_tensor.is_sparse(component):
+ self.assertSparseValuesEqual(component, result_component)
+ else:
+ self.assertAllEqual(component, result_component)
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ def testFromTensorSlicesWithDict(self):
+ components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
+ dataset = dataset_ops.Dataset.from_tensor_slices(components)
+ get_next = self.getNext(dataset)
+
+ self.assertEqual(dtypes.int32, dataset.output_types["foo"])
+ self.assertEqual(dtypes.float32, dataset.output_types["bar"])
+ self.assertEqual((), dataset.output_shapes["foo"])
+ self.assertEqual((1,), dataset.output_shapes["bar"])
+
+ for i in range(3):
+ results = self.evaluate(get_next())
+ self.assertEqual(components["foo"][i], results["foo"])
+ self.assertEqual(components["bar"][i], results["bar"])
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/from_tensors_test.py b/tensorflow/python/data/kernel_tests/from_tensors_test.py
new file mode 100644
index 0000000..2857817
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/from_tensors_test.py
@@ -0,0 +1,258 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.from_tensors()."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class FromTensorsTest(test_base.DatasetTestBase):
+
+ def testFromTensors(self):
+ """Test a dataset that represents a single tuple of tensors."""
+ components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
+
+ dataset = dataset_ops.Dataset.from_tensors(components)
+
+ self.assertEqual([c.shape for c in components],
+ nest.flatten(dataset.output_shapes))
+
+ self.assertDatasetProduces(dataset, expected_output=[components])
+
+ def testSkipEagerFromTensorsSparse(self):
+ """Test a dataset that represents a single tuple of tensors."""
+ components = (sparse_tensor.SparseTensorValue(
+ indices=np.array([[0]]),
+ values=np.array([0]),
+ dense_shape=np.array([1])),
+ sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0], [1, 1]]),
+ values=np.array([-1, 1]),
+ dense_shape=np.array([2, 2])))
+
+ dataset = dataset_ops.Dataset.from_tensors(components)
+
+ self.assertEqual(
+ [tensor_shape.TensorShape(c.dense_shape) for c in components],
+ [shape for shape in dataset.output_shapes])
+ self.assertDatasetProduces(dataset, expected_output=[components])
+
+ def testFromTensorsMixed(self):
+ """Test an dataset that represents a single tuple of tensors."""
+ components = (np.array(1), np.array([1, 2, 3]), np.array(37.0),
+ sparse_tensor.SparseTensorValue(
+ indices=np.array([[0]]),
+ values=np.array([0]),
+ dense_shape=np.array([1])),
+ sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0], [1, 1]]),
+ values=np.array([-1, 1]),
+ dense_shape=np.array([2, 2])))
+
+ dataset = dataset_ops.Dataset.from_tensors(components)
+ self.assertEqual([
+ tensor_shape.TensorShape(c.dense_shape)
+ if sparse_tensor.is_sparse(c) else c.shape for c in components
+ ], [shape for shape in dataset.output_shapes])
+
+ self.assertDatasetProduces(dataset, expected_output=[components])
+
+ # pylint: disable=g-long-lambda,unnecessary-lambda
+ def testNestedStructure(self):
+ components = (np.array([1, 2, 3], dtype=np.int64),
+ (np.array([4., 5.]), np.array([6., 7.])),
+ np.array([8, 9, 10], dtype=np.int64))
+
+ dataset = dataset_ops.Dataset.from_tensors(components)
+ self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
+ dtypes.int64), dataset.output_types)
+ self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
+
+ dataset = dataset.shuffle(10, 10)
+ self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
+ dtypes.int64), dataset.output_types)
+ self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
+
+ dataset = dataset.repeat(-1)
+ self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
+ dtypes.int64), dataset.output_types)
+ self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
+
+ dataset = dataset.filter(lambda x, y, z: True)
+ self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
+ dtypes.int64), dataset.output_types)
+ self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
+
+ dataset = dataset.take(5)
+ self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64),
+ dtypes.int64), dataset.output_types)
+ self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes)
+
+ dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1])))
+ self.assertEquals(((dtypes.int64, dtypes.int64),
+ (dtypes.float64, dtypes.float64)), dataset.output_types)
+ self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes)
+
+ dataset = dataset.flat_map(
+ lambda x, y: dataset_ops.Dataset.from_tensors(((x[0], x[1]),
+ (y[0], y[1])))
+ )
+ self.assertEquals(((dtypes.int64, dtypes.int64),
+ (dtypes.float64, dtypes.float64)), dataset.output_types)
+ self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes)
+
+ dataset = dataset.batch(32)
+ self.assertEquals(((dtypes.int64, dtypes.int64),
+ (dtypes.float64, dtypes.float64)), dataset.output_types)
+ self.assertEquals((([None, 3], [None, 3]), ([None, 2], [None, 2])),
+ nest.pack_sequence_as(dataset.output_shapes, [
+ s.as_list()
+ for s in nest.flatten(dataset.output_shapes)
+ ]))
+
+ # Define a separate set of components with matching leading
+ # dimension for the from-slices constructor.
+ components_for_slices = (np.array([1, 2, 3], dtype=np.int64),
+ (np.array([4., 5., 6.]), np.array([7., 8., 9.])),
+ np.array([10, 11, 12], dtype=np.int64))
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices)
+ self.assertEquals((dtypes.int64,
+ (dtypes.float64, dtypes.float64), dtypes.int64),
+ dataset.output_types)
+ self.assertEquals(([], ([], []), []), dataset.output_shapes)
+
+ # TODO(b/117581999): more specific shapes in eager mode.
+ def testSkipEagerNestedStructure(self):
+ components = (np.array([1, 2, 3], dtype=np.int64), (np.array([4., 5.]),
+ np.array([6., 7.])),
+ np.array([8, 9, 10], dtype=np.int64))
+
+ dataset = dataset_ops.Dataset.from_tensors(components)
+ dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1])))
+
+ dataset = dataset.flat_map(
+ lambda x, y: dataset_ops.Dataset.from_tensors(
+ ((x[0], x[1]), (y[0], y[1])))).batch(32)
+
+ get_next = self.getNext(dataset)
+ (w, x), (y, z) = get_next()
+ self.assertEquals(dtypes.int64, w.dtype)
+ self.assertEquals(dtypes.int64, x.dtype)
+ self.assertEquals(dtypes.float64, y.dtype)
+ self.assertEquals(dtypes.float64, z.dtype)
+ self.assertEquals([None, 3], w.shape.as_list())
+ self.assertEquals([None, 3], x.shape.as_list())
+ self.assertEquals([None, 2], y.shape.as_list())
+ self.assertEquals([None, 2], z.shape.as_list())
+
+ get_next = self.getNext(dataset)
+ (w, x), (y, z) = get_next()
+ self.assertEquals(dtypes.int64, w.dtype)
+ self.assertEquals(dtypes.int64, x.dtype)
+ self.assertEquals(dtypes.float64, y.dtype)
+ self.assertEquals(dtypes.float64, z.dtype)
+ self.assertEquals([None, 3], w.shape.as_list())
+ self.assertEquals([None, 3], x.shape.as_list())
+ self.assertEquals([None, 2], y.shape.as_list())
+ self.assertEquals([None, 2], z.shape.as_list())
+
+ def testNestedDict(self):
+ components = {"a": {"aa": 1, "ab": [2.0, 2.0]}, "b": [3, 3, 3]}
+ dataset = dataset_ops.Dataset.from_tensors(components)
+ self.assertEquals(dtypes.int32, dataset.output_types["a"]["aa"])
+ self.assertEquals(dtypes.float32, dataset.output_types["a"]["ab"])
+ self.assertEquals(dtypes.int32, dataset.output_types["b"])
+ self.assertEquals([], dataset.output_shapes["a"]["aa"])
+ self.assertEquals([2], dataset.output_shapes["a"]["ab"])
+ self.assertEquals([3], dataset.output_shapes["b"])
+
+ def testNonSequenceNestedStructure(self):
+ components = np.array([1, 2, 3], dtype=np.int64)
+
+ dataset = dataset_ops.Dataset.from_tensors(components)
+ self.assertEquals(dtypes.int64, dataset.output_types)
+ self.assertEquals([3], dataset.output_shapes)
+
+ dataset = dataset.filter(
+ lambda x: math_ops.reduce_all(math_ops.equal(x, components)))
+ self.assertEquals(dtypes.int64, dataset.output_types)
+ self.assertEquals([3], dataset.output_shapes)
+
+ dataset = dataset.map(lambda x: array_ops.stack([x, x]))
+ self.assertEquals(dtypes.int64, dataset.output_types)
+ self.assertEquals([2, 3], dataset.output_shapes)
+
+ dataset = dataset.flat_map(
+ lambda x: dataset_ops.Dataset.from_tensor_slices(x))
+ self.assertEquals(dtypes.int64, dataset.output_types)
+ self.assertEquals([3], dataset.output_shapes)
+
+ get_next = self.getNext(dataset)
+ self.assertEquals(dtypes.int64, get_next().dtype)
+ self.assertEquals([3], get_next().shape)
+
+ def testSkipEagerSplitPipelineFailsWithPlacementError(self):
+ with session.Session(
+ target="",
+ config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
+
+ dataset = dataset_ops.Dataset.from_tensors(0)
+
+ # Define a pipeline that attempts to use variables on two
+ # different devices.
+ #
+ # Initialize the variables before creating to iterator, to avoid the
+ # placement algorithm overriding the DT_RESOURCE colocation constraints.
+ with ops.device("/cpu:0"):
+ var_0 = resource_variable_ops.ResourceVariable(initial_value=0)
+ dataset = dataset.map(lambda x: x + var_0.read_value())
+ sess.run(var_0.initializer)
+
+ with ops.device("/cpu:1"):
+ var_1 = resource_variable_ops.ResourceVariable(initial_value=0)
+ dataset = dataset.map(lambda x: x + var_1.read_value())
+ sess.run(var_1.initializer)
+
+ iterator = dataset.make_initializable_iterator()
+ sess.run(iterator.initializer)
+
+ with self.assertRaisesRegexp(
+ errors.FailedPreconditionError,
+ "Error while reading resource variable Variable"):
+ sess.run(iterator.get_next())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/inputs_test.py b/tensorflow/python/data/kernel_tests/inputs_test.py
deleted file mode 100644
index d089b49..0000000
--- a/tensorflow/python/data/kernel_tests/inputs_test.py
+++ /dev/null
@@ -1,149 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import readers
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.platform import test
-
-
-class InputsTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @staticmethod
- def make_apply_fn(dataset):
-
- def apply_fn(dataset):
-
- def _apply_fn(dataset):
- return dataset.cache()
-
- return dataset.apply(_apply_fn)
-
- return apply_fn
-
- @staticmethod
- def make_gen():
-
- def gen():
- yield 42
-
- return gen
-
- @staticmethod
- def make_interleave_fn(dataset, num_parallel_calls=None):
-
- def interleave_fn(dataset):
- return dataset.interleave(
- lambda x: dataset_ops.Dataset.range(0),
- cycle_length=2,
- num_parallel_calls=num_parallel_calls)
-
- return interleave_fn
-
- @parameterized.named_parameters(
- ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)),
- ("FromGenerator",
- dataset_ops.Dataset.from_generator(make_gen.__func__(), dtypes.int32),
- 1),
- ("FromSparseTensorSlices",
- dataset_ops.Dataset.from_sparse_tensor_slices(
- sparse_tensor.SparseTensor(
- indices=np.array([[0, 0], [1, 0], [2, 0]]),
- values=np.array([0, 0, 0]),
- dense_shape=np.array([3, 1])))),
- ("FromTensors", dataset_ops.Dataset.from_tensors([42])),
- ("FromTensorSlices", dataset_ops.Dataset.from_tensors([42])),
- ("Range", dataset_ops.Dataset.range(10)),
- ("TextLine", readers.TextLineDataset("")),
- ("TFRecord", readers.TFRecordDataset(""), 1),
- )
- def testDatasetSourceInputs(self, dataset, num_inputs=0):
- self.assertEqual(num_inputs, len(dataset._inputs()))
-
- @parameterized.named_parameters(
- ("Apply", make_apply_fn.__func__(dataset_ops.Dataset.range(0)),
- dataset_ops.Dataset.range(0)),
- ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)),
- ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)),
- ("Filter", lambda x: x.filter(lambda x: True),
- dataset_ops.Dataset.range(0)),
- ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
- dataset_ops.Dataset.range(0)),
- ("Interleave", make_interleave_fn.__func__(dataset_ops.Dataset.range(0)),
- dataset_ops.Dataset.range(0)),
- ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)),
- ("PaddedBatch", lambda x: x.padded_batch(10, []),
- dataset_ops.Dataset.range(0)),
- ("ParallelInterleave",
- make_interleave_fn.__func__(dataset_ops.Dataset.range(0), 2),
- dataset_ops.Dataset.range(0)),
- ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
- dataset_ops.Dataset.range(0)),
- ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)),
- ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)),
- ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)),
- ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)),
- ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)),
- )
- def testUnaryTransformationInputs(self, dataset_fn, input_dataset):
- self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())
-
- @parameterized.named_parameters(
- ("Concatenate", lambda x, y: x.concatenate(y),
- dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))
- def testBinaryTransformationInputs(self, dataset_fn, input1, input2):
- self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())
-
- @parameterized.named_parameters(
- ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))),
- ("ZipNest", dataset_ops.Dataset.zip,
- (dataset_ops.Dataset.range(0),
- (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
- ("ZipTuple", dataset_ops.Dataset.zip,
- (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))))
- def testVariadicTransformationInputs(self, dataset_fn, input_datasets):
- self.assertEqual(
- nest.flatten(input_datasets),
- dataset_fn(input_datasets)._inputs())
-
- def testCollectInputs(self):
- ds1 = dataset_ops.Dataset.range(0)
- ds2 = ds1.concatenate(ds1)
- ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))
-
- inputs = []
- queue = [ds3]
- while queue:
- ds = queue[0]
- queue = queue[1:]
- queue.extend(ds._inputs())
- inputs.append(ds)
-
- self.assertEqual(5, inputs.count(ds1))
- self.assertEqual(2, inputs.count(ds2))
- self.assertEqual(1, inputs.count(ds3))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_test.py
similarity index 85%
rename from tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
rename to tensorflow/python/data/kernel_tests/interleave_test.py
index 56434d6..cd1d850 100644
--- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.Dataset.interleave()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -27,6 +27,7 @@
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
@@ -133,7 +134,8 @@
return [[value] * value for value in np.tile(values, count)]
-class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
+@test_util.run_all_in_graph_and_eager_modes
+class InterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("1", [4, 5, 6], 1, 1, [
@@ -191,16 +193,11 @@
count).interleave(
lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
cycle_length, block_length, num_parallel_calls)
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- for expected_element in _interleave(
- _repeat(input_values, count), cycle_length, block_length):
- self.assertEqual(expected_element, self.evaluate(get_next))
-
- for _ in range(2):
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ expected_output = [
+ element for element in _interleave(
+ _repeat(input_values, count), cycle_length, block_length)
+ ]
+ self.assertDatasetProduces(dataset, expected_output)
@parameterized.named_parameters(
("1", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, None),
@@ -223,17 +220,16 @@
lambda x: array_ops.check_numerics(x, "message")).interleave(
dataset_ops.Dataset.from_tensors, cycle_length, block_length,
num_parallel_calls)
- get_next = dataset.make_one_shot_iterator().get_next()
+ get_next = self.getNext(dataset)
- with self.cached_session() as sess:
- for value in input_values:
- if np.isnan(value):
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
- else:
- self.assertEqual(value, self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ for value in input_values:
+ if np.isnan(value):
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(get_next())
+ else:
+ self.assertEqual(value, self.evaluate(get_next()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
def testInterleaveSparse(self):
@@ -245,18 +241,17 @@
return dataset_ops.Dataset.from_tensor_slices(
sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
- iterator = (
- dataset_ops.Dataset.range(10).map(_map_fn).interleave(
- _interleave_fn, cycle_length=1).make_one_shot_iterator())
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- for j in range(2):
- expected = [i, 0] if j % 2 == 0 else [0, -i]
- self.assertAllEqual(expected, self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
+ dataset = dataset_ops.Dataset.range(10).map(_map_fn).interleave(
+ _interleave_fn, cycle_length=1)
+ get_next = self.getNext(dataset)
+ for i in range(10):
+ for j in range(2):
+ expected = [i, 0] if j % 2 == 0 else [0, -i]
+ self.assertAllEqual(expected, self.evaluate(get_next()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
@parameterized.named_parameters(
("1", np.int64([4, 5, 6]), 2, 1, 1),
@@ -269,8 +264,8 @@
("8", np.int64([4, 0, 6]), 2, 3, 1),
("9", np.int64([4, 0, 6]), 2, 3, 2),
)
- def testSloppyInterleaveInOrder(self, input_values, cycle_length,
- block_length, num_parallel_calls):
+ def testSkipEagerSloppyInterleaveInOrder(self, input_values, cycle_length,
+ block_length, num_parallel_calls):
get_next, coordination_events = _make_coordinated_sloppy_dataset(
input_values, cycle_length, block_length, num_parallel_calls)
config = config_pb2.ConfigProto(
@@ -281,7 +276,7 @@
_repeat(input_values, 2), cycle_length, block_length):
coordination_events[expected_element].set()
self.assertEqual(expected_element * expected_element,
- sess.run(get_next))
+ self.evaluate(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -291,8 +286,8 @@
("3", np.int64([4, 5, 6]), 3, 2, 3),
("4", np.int64([4, 0, 6]), 2, 3, 2),
)
- def testSloppyInterleaveOutOfOrder(self, input_values, cycle_length,
- block_length, num_parallel_calls):
+ def testSkipEagerSloppyInterleaveOutOfOrder(self, input_values, cycle_length,
+ block_length, num_parallel_calls):
get_next, coordination_events = _make_coordinated_sloppy_dataset(
input_values, cycle_length, block_length, num_parallel_calls)
config = config_pb2.ConfigProto(
diff --git a/tensorflow/python/data/kernel_tests/iterator_checkpoint_test.py b/tensorflow/python/data/kernel_tests/iterator_checkpoint_test.py
new file mode 100644
index 0000000..fc4164c
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/iterator_checkpoint_test.py
@@ -0,0 +1,129 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Checkpoint tests for `tf.data.Iterator`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import os
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training.checkpointable import util as checkpointable_utils
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class IteratorCheckpointingTest(test_base.DatasetTestBase):
+
+ def testSaveRestoreOneShotIterator(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(
+ math_ops.square).batch(2)
+ iterator = iter(dataset) if context.executing_eagerly(
+ ) else dataset.make_one_shot_iterator()
+ get_next = iterator.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator.get_next())
+ checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+ self.assertAllEqual([1, 4], get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual([9, 16], get_next())
+ self.assertAllEqual([25, 36], get_next())
+ checkpoint.restore(save_path).run_restore_ops()
+ self.assertAllEqual([9, 16], get_next())
+ self.assertAllEqual([25, 36], get_next())
+ with self.assertRaises(errors.OutOfRangeError):
+ get_next()
+
+ def testSaveRestoreMultipleIterator(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dataset = dataset_ops.Dataset.from_tensor_slices(
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+ dataset = dataset.map(math_ops.square).batch(2)
+ iterator_1 = iter(dataset) if context.executing_eagerly(
+ ) else dataset.make_one_shot_iterator()
+ get_next_1 = iterator_1.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator_1.get_next())
+ iterator_2 = iter(dataset) if context.executing_eagerly(
+ ) else dataset.make_one_shot_iterator()
+ get_next_2 = iterator_2.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator_2.get_next())
+ dataset_2 = dataset_ops.Dataset.range(10)
+ iterator_3 = iter(dataset_2) if context.executing_eagerly(
+ ) else dataset_2.make_one_shot_iterator()
+ get_next_3 = iterator_3.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator_3.get_next())
+ checkpoint = checkpointable_utils.Checkpoint(
+ iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
+ self.assertAllEqual([1, 4], get_next_1())
+ self.assertAllEqual(0, get_next_3())
+ self.assertAllEqual(1, get_next_3())
+ self.assertAllEqual(2, get_next_3())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual([1, 4], get_next_2())
+ self.assertAllEqual([9, 16], get_next_2())
+ self.assertAllEqual(3, get_next_3())
+ checkpoint.restore(save_path).run_restore_ops()
+ self.assertAllEqual([9, 16], get_next_1())
+ self.assertAllEqual([1, 4], get_next_2())
+ self.assertAllEqual(3, get_next_3())
+
+ def testRestoreExhaustedIterator(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dataset = dataset_ops.Dataset.range(3)
+ iterator = iter(dataset) if context.executing_eagerly(
+ ) else dataset.make_one_shot_iterator()
+ get_next = iterator.get_next if context.executing_eagerly(
+ ) else functools.partial(self.evaluate, iterator.get_next())
+ checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+ self.assertAllEqual(0, get_next())
+ self.assertAllEqual(1, get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ self.assertAllEqual(2, get_next())
+ checkpoint.restore(save_path).run_restore_ops()
+ self.assertAllEqual(2, get_next())
+ save_path = checkpoint.save(checkpoint_prefix)
+ checkpoint.restore(save_path).run_restore_ops()
+ with self.assertRaises(errors.OutOfRangeError):
+ get_next()
+
+ def testRestoreInReconstructedIteratorInitializable(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ dataset = dataset_ops.Dataset.range(10)
+ iterator = iter(dataset) if context.executing_eagerly(
+ ) else dataset.make_initializable_iterator()
+ get_next = iterator.get_next
+ checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
+ for i in range(5):
+ checkpoint.restore(
+ checkpoint_management.latest_checkpoint(
+ checkpoint_directory)).initialize_or_restore()
+ for j in range(2):
+ self.assertEqual(i * 2 + j, self.evaluate(get_next()))
+ checkpoint.save(file_prefix=checkpoint_prefix)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/python/data/kernel_tests/iterator_cluster_test.py
similarity index 95%
rename from tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
rename to tensorflow/python/data/kernel_tests/iterator_cluster_test.py
index cb38728..c1f856e 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_cluster_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops that need test_util."""
+"""Tests for `tf.data.Iterator` using distributed sessions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -57,7 +57,7 @@
with session.Session(worker[0].target) as sess:
with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(get_next_op)
+ sess.run(get_next_op)
def _testRemoteIteratorHelper(self, device0, device1, target):
with ops.device(device1):
@@ -134,12 +134,12 @@
get_next = iterator.get_next()
with session.Session(worker[0].target) as sess:
- self.evaluate(table.initializer)
- self.evaluate(init_op)
- self.assertAllEqual([0, 0, -1, 1, 2], self.evaluate(get_next))
+ sess.run(table.initializer)
+ sess.run(init_op)
+ self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))
with session.Session(worker[0].target) as sess:
- self.assertAllEqual([2, 0], self.evaluate(get_next))
+ self.assertAllEqual([2, 0], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -166,7 +166,7 @@
get_next = iterator.get_next()
with session.Session(worker[0].target) as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for _ in range(3):
sess.run(get_next)
diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py
similarity index 85%
rename from tensorflow/python/data/kernel_tests/iterator_ops_test.py
rename to tensorflow/python/data/kernel_tests/iterator_test.py
index 405d94d..de95a53 100644
--- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/iterator_test.py
@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.Iterator`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
import os
import warnings
@@ -50,9 +49,7 @@
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import server_lib
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import compat
@@ -97,7 +94,7 @@
with self.cached_session() as sess:
for _ in range(14):
for i in range(7):
- result = self.evaluate(get_next)
+ result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
@@ -123,7 +120,7 @@
with self.cached_session() as sess:
for _ in range(14):
for i in range(7):
- result = self.evaluate(get_next)
+ result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
@@ -159,7 +156,7 @@
for _ in range(14):
for i in range(7):
- result = self.evaluate(get_next)
+ result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
@@ -175,7 +172,7 @@
config = config_pb2.ConfigProto(
inter_op_parallelism_threads=1, use_per_session_threads=True)
with session.Session(config=config) as sess:
- self.assertAllEqual([1, 4, 9], self.evaluate(next_element))
+ self.assertAllEqual([1, 4, 9], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@@ -254,15 +251,15 @@
get_next = iterator.get_next()
with session.Session(server.target) as sess:
- self.evaluate(init_op)
- results = self.evaluate(get_next)
+ sess.run(init_op)
+ results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Re-initialize the iterator in the first session.
- self.evaluate(init_op)
+ sess.run(init_op)
with ops.Graph().as_default():
# Re-define the iterator manually, without defining any of the
@@ -277,7 +274,7 @@
with session.Session(server.target) as sess:
# Use the iterator without re-initializing in the second session.
- results = self.evaluate(get_next)
+ results = sess.run(get_next)
for component, result_component in zip(components, results):
self.assertAllEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
@@ -317,20 +314,20 @@
sess.run(get_next)
# Initialize with one dataset.
- self.evaluate(dataset_3_init_op)
- self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
+ sess.run(dataset_3_init_op)
+ self.assertAllEqual([1, 2, 3], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Initialize with a different dataset.
- self.evaluate(dataset_4_init_op)
- self.assertAllEqual([4, 5, 6, 7], self.evaluate(get_next))
+ sess.run(dataset_4_init_op)
+ self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Reinitialize with the first dataset.
- self.evaluate(dataset_3_init_op)
- self.assertAllEqual([1, 2, 3], self.evaluate(get_next))
+ sess.run(dataset_3_init_op)
+ self.assertAllEqual([1, 2, 3], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -348,7 +345,7 @@
g, output_types=dtypes.int64)
sess.run(iterator.make_initializer(dataset_1))
for expected in range(10):
- self.assertEqual(expected, self.evaluate(next_element))
+ self.assertEqual(expected, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@@ -356,7 +353,7 @@
g, output_types=dtypes.int64)
sess.run(iterator.make_initializer(dataset_2))
for expected in range(10):
- self.assertEqual(expected, self.evaluate(next_element))
+ self.assertEqual(expected, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@@ -679,10 +676,10 @@
n = itr.get_next()
with session.Session(s3.target, config=config) as sess:
- self.evaluate(itr.initializer)
+ sess.run(itr.initializer)
expected_values = worker_devices
for expected in expected_values:
- self.assertEqual((compat.as_bytes(expected),), self.evaluate(n))
+ self.assertEqual((compat.as_bytes(expected),), sess.run(n))
with self.assertRaises(errors.OutOfRangeError):
sess.run(n)
@@ -786,8 +783,8 @@
with ops.Graph().as_default() as g:
init_op, _, save_op, _ = _build_range_dataset_graph()
with self.session(graph=g) as sess:
- self.evaluate(init_op)
- self.evaluate(save_op)
+ sess.run(init_op)
+ sess.run(save_op)
# Attempt to restore the saved iterator into an IteratorResource of
# incompatible type. An iterator of RangeDataset has output type int64,
@@ -798,7 +795,7 @@
_, _, _, restore_op = _build_reader_dataset_graph()
with self.session(graph=g) as sess:
with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(restore_op)
+ sess.run(restore_op)
def testRepeatedGetNextWarning(self):
iterator = dataset_ops.Dataset.range(10).make_one_shot_iterator()
@@ -863,95 +860,5 @@
self.assertEqual("overridden_name", next_element.op.name)
-class IteratorCheckpointingTest(test.TestCase):
-
- @test_util.run_in_graph_and_eager_modes
- def testSaveRestoreOneShotIterator(self):
- checkpoint_directory = self.get_temp_dir()
- checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(
- math_ops.square).batch(2)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next if context.executing_eagerly(
- ) else functools.partial(self.evaluate, iterator.get_next())
- checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
- self.assertAllEqual([1, 4], get_next())
- save_path = checkpoint.save(checkpoint_prefix)
- self.assertAllEqual([9, 16], get_next())
- self.assertAllEqual([25, 36], get_next())
- checkpoint.restore(save_path).run_restore_ops()
- self.assertAllEqual([9, 16], get_next())
- self.assertAllEqual([25, 36], get_next())
- with self.assertRaises(errors.OutOfRangeError):
- get_next()
-
- @test_util.run_in_graph_and_eager_modes
- def testSaveRestoreMultipleIterator(self):
- checkpoint_directory = self.get_temp_dir()
- checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- dataset = dataset_ops.Dataset.from_tensor_slices(
- [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
- dataset = dataset.map(math_ops.square).batch(2)
- iterator_1 = dataset.make_one_shot_iterator()
- get_next_1 = iterator_1.get_next if context.executing_eagerly(
- ) else functools.partial(self.evaluate, iterator_1.get_next())
- iterator_2 = dataset.make_one_shot_iterator()
- get_next_2 = iterator_2.get_next if context.executing_eagerly(
- ) else functools.partial(self.evaluate, iterator_2.get_next())
- dataset_2 = dataset_ops.Dataset.range(10)
- iterator_3 = dataset_2.make_one_shot_iterator()
- get_next_3 = iterator_3.get_next if context.executing_eagerly(
- ) else functools.partial(self.evaluate, iterator_3.get_next())
- checkpoint = checkpointable_utils.Checkpoint(
- iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)
- self.assertAllEqual([1, 4], get_next_1())
- self.assertAllEqual(0, get_next_3())
- self.assertAllEqual(1, get_next_3())
- self.assertAllEqual(2, get_next_3())
- save_path = checkpoint.save(checkpoint_prefix)
- self.assertAllEqual([1, 4], get_next_2())
- self.assertAllEqual([9, 16], get_next_2())
- self.assertAllEqual(3, get_next_3())
- checkpoint.restore(save_path).run_restore_ops()
- self.assertAllEqual([9, 16], get_next_1())
- self.assertAllEqual([1, 4], get_next_2())
- self.assertAllEqual(3, get_next_3())
-
- @test_util.run_in_graph_and_eager_modes
- def testRestoreExhaustedIterator(self):
- checkpoint_directory = self.get_temp_dir()
- checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- dataset = dataset_ops.Dataset.range(3)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next if context.executing_eagerly(
- ) else functools.partial(self.evaluate, iterator.get_next())
- checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
- self.assertAllEqual(0, get_next())
- self.assertAllEqual(1, get_next())
- save_path = checkpoint.save(checkpoint_prefix)
- self.assertAllEqual(2, get_next())
- checkpoint.restore(save_path).run_restore_ops()
- self.assertAllEqual(2, get_next())
- save_path = checkpoint.save(checkpoint_prefix)
- checkpoint.restore(save_path).run_restore_ops()
- with self.assertRaises(errors.OutOfRangeError):
- get_next()
-
- def testRestoreInReconstructedIteratorInitializable(self):
- checkpoint_directory = self.get_temp_dir()
- checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- dataset = dataset_ops.Dataset.range(10)
- iterator = dataset.make_initializable_iterator()
- get_next = iterator.get_next()
- checkpoint = checkpointable_utils.Checkpoint(iterator=iterator)
- for i in range(5):
- with self.cached_session() as sess:
- checkpoint.restore(checkpoint_management.latest_checkpoint(
- checkpoint_directory)).initialize_or_restore(sess)
- for j in range(2):
- self.assertEqual(i * 2 + j, self.evaluate(get_next))
- checkpoint.save(file_prefix=checkpoint_prefix)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
deleted file mode 100644
index ac6fbab..0000000
--- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
+++ /dev/null
@@ -1,291 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from os import path
-import shutil
-import tempfile
-
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-class ListFilesDatasetOpTest(test_base.DatasetTestBase):
-
- def setUp(self):
- self.tmp_dir = tempfile.mkdtemp()
-
- def tearDown(self):
- shutil.rmtree(self.tmp_dir, ignore_errors=True)
-
- def _touchTempFiles(self, filenames):
- for filename in filenames:
- open(path.join(self.tmp_dir, filename), 'a').close()
-
- def testEmptyDirectory(self):
- dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
- with self.cached_session() as sess:
- itr = dataset.make_one_shot_iterator()
- next_element = itr.get_next()
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testSimpleDirectory(self):
- filenames = ['a', 'b', 'c']
- self._touchTempFiles(filenames)
-
- dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
- with self.cached_session() as sess:
- itr = dataset.make_one_shot_iterator()
- next_element = itr.get_next()
-
- full_filenames = []
- produced_filenames = []
- for filename in filenames:
- full_filenames.append(
- compat.as_bytes(path.join(self.tmp_dir, filename)))
- produced_filenames.append(compat.as_bytes(sess.run(next_element)))
- self.assertItemsEqual(full_filenames, produced_filenames)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(itr.get_next())
-
- def testSimpleDirectoryNotShuffled(self):
- filenames = ['b', 'c', 'a']
- self._touchTempFiles(filenames)
-
- dataset = dataset_ops.Dataset.list_files(
- path.join(self.tmp_dir, '*'), shuffle=False)
- with self.cached_session() as sess:
- itr = dataset.make_one_shot_iterator()
- next_element = itr.get_next()
-
- for filename in sorted(filenames):
- self.assertEqual(compat.as_bytes(path.join(self.tmp_dir, filename)),
- sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(itr.get_next())
-
- def testFixedSeedResultsInRepeatableOrder(self):
- filenames = ['a', 'b', 'c']
- self._touchTempFiles(filenames)
-
- dataset = dataset_ops.Dataset.list_files(
- path.join(self.tmp_dir, '*'), shuffle=True, seed=37)
- with self.cached_session() as sess:
- itr = dataset.make_initializable_iterator()
- next_element = itr.get_next()
-
- full_filenames = [compat.as_bytes(path.join(self.tmp_dir, filename))
- for filename in filenames]
-
- all_produced_filenames = []
- for _ in range(3):
- produced_filenames = []
- self.evaluate(itr.initializer)
- try:
- while True:
- produced_filenames.append(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
- all_produced_filenames.append(produced_filenames)
-
- # Each run should produce the same set of filenames, which may be
- # different from the order of `full_filenames`.
- self.assertItemsEqual(full_filenames, all_produced_filenames[0])
- # However, the different runs should produce filenames in the same order
- # as each other.
- self.assertEqual(all_produced_filenames[0], all_produced_filenames[1])
- self.assertEqual(all_produced_filenames[0], all_produced_filenames[2])
-
- def testEmptyDirectoryInitializer(self):
- filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
- dataset = dataset_ops.Dataset.list_files(filename_placeholder)
-
- with self.cached_session() as sess:
- itr = dataset.make_initializable_iterator()
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError, 'No files matched pattern: '):
- sess.run(
- itr.initializer,
- feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
-
- def testSimpleDirectoryInitializer(self):
- filenames = ['a', 'b', 'c']
- self._touchTempFiles(filenames)
-
- filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
- dataset = dataset_ops.Dataset.list_files(filename_placeholder)
-
- with self.cached_session() as sess:
- itr = dataset.make_initializable_iterator()
- next_element = itr.get_next()
- sess.run(
- itr.initializer,
- feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')})
-
- full_filenames = []
- produced_filenames = []
- for filename in filenames:
- full_filenames.append(
- compat.as_bytes(path.join(self.tmp_dir, filename)))
- produced_filenames.append(compat.as_bytes(sess.run(next_element)))
-
- self.assertItemsEqual(full_filenames, produced_filenames)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(itr.get_next())
-
- def testFileSuffixes(self):
- filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc']
- self._touchTempFiles(filenames)
-
- filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
- dataset = dataset_ops.Dataset.list_files(filename_placeholder)
-
- with self.cached_session() as sess:
- itr = dataset.make_initializable_iterator()
- next_element = itr.get_next()
- sess.run(
- itr.initializer,
- feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py')})
-
- full_filenames = []
- produced_filenames = []
- for filename in filenames[1:-1]:
- full_filenames.append(
- compat.as_bytes(path.join(self.tmp_dir, filename)))
- produced_filenames.append(compat.as_bytes(sess.run(next_element)))
- self.assertItemsEqual(full_filenames, produced_filenames)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(itr.get_next())
-
- def testFileMiddles(self):
- filenames = ['a.txt', 'b.py', 'c.pyc']
- self._touchTempFiles(filenames)
-
- filename_placeholder = array_ops.placeholder(dtypes.string, shape=[])
- dataset = dataset_ops.Dataset.list_files(filename_placeholder)
-
- with self.cached_session() as sess:
- itr = dataset.make_initializable_iterator()
- next_element = itr.get_next()
- sess.run(
- itr.initializer,
- feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py*')})
-
- full_filenames = []
- produced_filenames = []
- for filename in filenames[1:]:
- full_filenames.append(
- compat.as_bytes(path.join(self.tmp_dir, filename)))
- produced_filenames.append(compat.as_bytes(sess.run(next_element)))
-
- self.assertItemsEqual(full_filenames, produced_filenames)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(itr.get_next())
-
- def testNoShuffle(self):
- filenames = ['a', 'b', 'c']
- self._touchTempFiles(filenames)
-
- # Repeat the list twice and ensure that the order is the same each time.
- # NOTE(mrry): This depends on an implementation detail of `list_files()`,
- # which is that the list of files is captured when the iterator is
- # initialized. Otherwise, or if e.g. the iterator were initialized more than
- # once, it's possible that the non-determinism of `tf.matching_files()`
- # would cause this test to fail. However, it serves as a useful confirmation
- # that the `shuffle=False` argument is working as intended.
- # TODO(b/73959787): Provide some ordering guarantees so that this test is
- # more meaningful.
- dataset = dataset_ops.Dataset.list_files(
- path.join(self.tmp_dir, '*'), shuffle=False).repeat(2)
- with self.cached_session() as sess:
- itr = dataset.make_one_shot_iterator()
- next_element = itr.get_next()
-
- full_filenames = []
- produced_filenames = []
- for filename in filenames * 2:
- full_filenames.append(
- compat.as_bytes(path.join(self.tmp_dir, filename)))
- produced_filenames.append(compat.as_bytes(sess.run(next_element)))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(itr.get_next())
- self.assertItemsEqual(full_filenames, produced_filenames)
- self.assertEqual(produced_filenames[:len(filenames)],
- produced_filenames[len(filenames):])
-
- def testMultiplePatternsAsList(self):
- filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc']
- self._touchTempFiles(filenames)
-
- patterns = [path.join(self.tmp_dir, pat) for pat in ['*.py', '*.txt']]
- dataset = dataset_ops.Dataset.list_files(patterns)
- with self.cached_session() as sess:
- itr = dataset.make_one_shot_iterator()
- next_element = itr.get_next()
-
- full_filenames = []
- produced_filenames = []
- for filename in filenames[:-1]:
- full_filenames.append(
- compat.as_bytes(path.join(self.tmp_dir, filename)))
- produced_filenames.append(compat.as_bytes(sess.run(next_element)))
- self.assertItemsEqual(full_filenames, produced_filenames)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(itr.get_next())
-
- def testMultiplePatternsAsTensor(self):
- filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc']
- self._touchTempFiles(filenames)
-
- filename_placeholder = array_ops.placeholder(
- dtypes.string, shape=[
- 2,
- ])
- dataset = dataset_ops.Dataset.list_files(filename_placeholder)
-
- with self.cached_session() as sess:
- itr = dataset.make_initializable_iterator()
- next_element = itr.get_next()
- patterns = [path.join(self.tmp_dir, pat) for pat in ['*.py', '*.txt']]
- sess.run(itr.initializer, feed_dict={filename_placeholder: patterns})
-
- full_filenames = []
- produced_filenames = []
- for filename in filenames[:-1]:
- full_filenames.append(
- compat.as_bytes(path.join(self.tmp_dir, filename)))
- produced_filenames.append(compat.as_bytes(sess.run(next_element)))
- self.assertItemsEqual(full_filenames, produced_filenames)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(itr.get_next())
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/list_files_test.py b/tensorflow/python/data/kernel_tests/list_files_test.py
new file mode 100644
index 0000000..26c5360
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/list_files_test.py
@@ -0,0 +1,213 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.list_files()`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from os import path
+import shutil
+import tempfile
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class ListFilesTest(test_base.DatasetTestBase):
+
+ def setUp(self):
+ self.tmp_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.tmp_dir, ignore_errors=True)
+
+ def _touchTempFiles(self, filenames):
+ for filename in filenames:
+ open(path.join(self.tmp_dir, filename), 'a').close()
+
+ # Note: eager mode fails in assertion error same as initializer in graph mode.
+ def testSkipEagerEmptyDirectory(self):
+ dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
+ self.assertDatasetProduces(dataset, expected_output=[])
+
+ def testSimpleDirectory(self):
+ filenames = ['a', 'b', 'c']
+ self._touchTempFiles(filenames)
+
+ dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
+ self.assertDatasetProduces(
+ dataset,
+ expected_output=[
+ compat.as_bytes(path.join(self.tmp_dir, filename))
+ for filename in filenames
+ ],
+ assert_items_equal=True)
+
+ def testSimpleDirectoryNotShuffled(self):
+ filenames = ['b', 'c', 'a']
+ self._touchTempFiles(filenames)
+
+ dataset = dataset_ops.Dataset.list_files(
+ path.join(self.tmp_dir, '*'), shuffle=False)
+ self.assertDatasetProduces(
+ dataset,
+ expected_output=[
+ compat.as_bytes(path.join(self.tmp_dir, filename))
+ for filename in sorted(filenames)
+ ])
+
+ def testFixedSeedResultsInRepeatableOrder(self):
+ filenames = ['a', 'b', 'c']
+ self._touchTempFiles(filenames)
+
+ dataset = dataset_ops.Dataset.list_files(
+ path.join(self.tmp_dir, '*'), shuffle=True, seed=37)
+
+ full_filenames = [compat.as_bytes(path.join(self.tmp_dir, filename))
+ for filename in filenames]
+
+ all_produced_filenames = []
+ for _ in range(3):
+ produced_filenames = []
+ next_element = self.getNext(dataset, requires_initialization=True)
+ try:
+ while True:
+ produced_filenames.append(self.evaluate(next_element()))
+ except errors.OutOfRangeError:
+ pass
+ all_produced_filenames.append(produced_filenames)
+
+ # Each run should produce the same set of filenames, which may be
+ # different from the order of `full_filenames`.
+ self.assertItemsEqual(full_filenames, all_produced_filenames[0])
+ # However, the different runs should produce filenames in the same order
+ # as each other.
+ self.assertEqual(all_produced_filenames[0], all_produced_filenames[1])
+ self.assertEqual(all_produced_filenames[0], all_produced_filenames[2])
+
+ # TODO(b/117581999): eager mode assertion fail wrapped, debug.
+ def tesSkipEagerEmptyDirectoryInitializer(self):
+ dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
+ self.assertDatasetProduces(
+ dataset,
+ expected_error=(errors.InvalidArgumentError,
+ 'No files matched pattern'),
+ requires_initialization=True)
+
+ def testSimpleDirectoryInitializer(self):
+ filenames = ['a', 'b', 'c']
+ self._touchTempFiles(filenames)
+
+ dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*'))
+ self.assertDatasetProduces(
+ dataset,
+ expected_output=[
+ compat.as_bytes(path.join(self.tmp_dir, filename))
+ for filename in filenames
+ ],
+ assert_items_equal=True)
+
+ def testFileSuffixes(self):
+ filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc']
+ self._touchTempFiles(filenames)
+
+ dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*.py'))
+ self.assertDatasetProduces(
+ dataset,
+ expected_output=[
+ compat.as_bytes(path.join(self.tmp_dir, filename))
+ for filename in filenames[1:-1]
+ ],
+ assert_items_equal=True)
+
+ def testFileMiddles(self):
+ filenames = ['a.txt', 'b.py', 'c.pyc']
+ self._touchTempFiles(filenames)
+
+ dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*.py*'))
+ self.assertDatasetProduces(
+ dataset,
+ expected_output=[
+ compat.as_bytes(path.join(self.tmp_dir, filename))
+ for filename in filenames[1:]
+ ],
+ assert_items_equal=True)
+
+ def testNoShuffle(self):
+ filenames = ['a', 'b', 'c']
+ self._touchTempFiles(filenames)
+
+ # Repeat the list twice and ensure that the order is the same each time.
+ # NOTE(mrry): This depends on an implementation detail of `list_files()`,
+ # which is that the list of files is captured when the iterator is
+ # initialized. Otherwise, or if e.g. the iterator were initialized more than
+ # once, it's possible that the non-determinism of `tf.matching_files()`
+ # would cause this test to fail. However, it serves as a useful confirmation
+ # that the `shuffle=False` argument is working as intended.
+ # TODO(b/73959787): Provide some ordering guarantees so that this test is
+ # more meaningful.
+ dataset = dataset_ops.Dataset.list_files(
+ path.join(self.tmp_dir, '*'), shuffle=False).repeat(2)
+ next_element = self.getNext(dataset)
+
+ full_filenames = []
+ produced_filenames = []
+ for filename in filenames * 2:
+ full_filenames.append(compat.as_bytes(path.join(self.tmp_dir, filename)))
+ produced_filenames.append(compat.as_bytes(self.evaluate(next_element())))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(next_element())
+ self.assertItemsEqual(full_filenames, produced_filenames)
+ self.assertEqual(produced_filenames[:len(filenames)],
+ produced_filenames[len(filenames):])
+
+ def testMultiplePatternsAsList(self):
+ filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc']
+ self._touchTempFiles(filenames)
+
+ patterns = [path.join(self.tmp_dir, pat) for pat in ['*.py', '*.txt']]
+ dataset = dataset_ops.Dataset.list_files(patterns)
+ self.assertDatasetProduces(
+ dataset,
+ expected_output=[
+ compat.as_bytes(path.join(self.tmp_dir, filename))
+ for filename in filenames[:-1]
+ ],
+ assert_items_equal=True)
+
+ def testMultiplePatternsAsTensor(self):
+ filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc']
+ self._touchTempFiles(filenames)
+
+ dataset = dataset_ops.Dataset.list_files(
+ [path.join(self.tmp_dir, pat) for pat in ['*.py', '*.txt']])
+ self.assertDatasetProduces(
+ dataset,
+ expected_output=[
+ compat.as_bytes(path.join(self.tmp_dir, filename))
+ for filename in filenames[:-1]
+ ],
+ assert_items_equal=True)
+
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_test.py
similarity index 84%
rename from tensorflow/python/data/kernel_tests/map_dataset_op_test.py
rename to tensorflow/python/data/kernel_tests/map_test.py
index 8f7a19d..a9c4d79 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_test.py
@@ -19,7 +19,6 @@
from collections import namedtuple
import threading
-import time
import warnings
from absl.testing import parameterized
@@ -27,7 +26,6 @@
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.client import session
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
@@ -114,7 +112,7 @@
sess.run(init_op, feed_dict={count: 14})
for _ in range(14):
for i in range(7):
- result = self.evaluate(get_next)
+ result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
@@ -185,7 +183,7 @@
output_buffer_size: output_buffer_size_val})
for _ in range(14):
for i in range(7):
- result = self.evaluate(get_next)
+ result = sess.run(get_next)
for component, result_component in zip(components, result):
self.assertAllEqual(component[i]**2, result_component)
with self.assertRaises(errors.OutOfRangeError):
@@ -242,7 +240,7 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for _ in range(3):
sess.run(get_next)
@@ -257,7 +255,7 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for _ in range(3):
sess.run(get_next)
@@ -272,7 +270,7 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for _ in range(3):
sess.run(get_next)
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
@@ -293,7 +291,7 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for _ in range(3):
sess.run(get_next)
# The 4th element is NaN, so `array_ops.check_numerics()` should fail.
@@ -325,10 +323,10 @@
with ops.Graph().as_default() as g:
captured_init_op, init_op, get_next = _build_graph()
with self.session(graph=g) as sess:
- self.evaluate(captured_init_op)
- self.evaluate(init_op)
+ sess.run(captured_init_op)
+ sess.run(init_op)
for i in range(10):
- self.assertEqual(i * i, self.evaluate(get_next))
+ self.assertEqual(i * i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -353,8 +351,8 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(table.initializer)
- self.evaluate(init_op)
+ sess.run(table.initializer)
+ sess.run(init_op)
sess.run(get_next)
sess.run(get_next)
with self.assertRaises(errors.OutOfRangeError):
@@ -371,11 +369,11 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(enqueue_op)
- self.evaluate(close_op)
- self.evaluate(init_op)
+ sess.run(enqueue_op)
+ sess.run(close_op)
+ sess.run(init_op)
for element in elements:
- self.assertEqual(element, self.evaluate(get_next))
+ self.assertEqual(element, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -396,9 +394,9 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(enqueue_op)
- self.evaluate(close_op)
- self.evaluate(init_op)
+ sess.run(enqueue_op)
+ sess.run(close_op)
+ sess.run(init_op)
for i in range(100):
self.assertEqual(sorted([elements[i * 2], elements[i * 2 + 1]]),
sorted(sess.run(get_next)))
@@ -415,15 +413,15 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(counter_var.initializer)
- self.evaluate(init_op)
+ sess.run(counter_var.initializer)
+ sess.run(init_op)
for i in range(10):
- self.assertEqual(i, self.evaluate(counter_var))
- self.assertEqual(i + 1, self.evaluate(get_next))
- self.assertEqual(10, self.evaluate(counter_var))
+ self.assertEqual(i, sess.run(counter_var))
+ self.assertEqual(i + 1, sess.run(get_next))
+ self.assertEqual(10, sess.run(counter_var))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- self.assertEqual(10, self.evaluate(counter_var))
+ self.assertEqual(10, sess.run(counter_var))
def testCaptureUninitializedVariableError(self):
counter_var = variable_scope.get_variable(
@@ -435,7 +433,7 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
with self.assertRaises(errors.NotFoundError):
sess.run(get_next)
@@ -447,14 +445,14 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
random_values = []
with self.assertRaises(errors.OutOfRangeError):
while True:
random_values.extend(sess.run(get_next))
self.assertEqual(10, len(random_values))
self.assertGreater(np.abs(np.diff(random_values)).max(), 1e-6)
- self.evaluate(init_op)
+ sess.run(init_op)
random_values_2 = []
with self.assertRaises(errors.OutOfRangeError):
while True:
@@ -473,8 +471,8 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
- random_values = self.evaluate(get_next)
+ sess.run(init_op)
+ random_values = sess.run(get_next)
# Assert that one of the next 99 batches yielded by the iterator is
# different from the first.
@@ -500,15 +498,15 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(counter_var.initializer)
- self.evaluate(init_op)
+ sess.run(counter_var.initializer)
+ sess.run(init_op)
for i in range(10):
- self.assertEqual(i, self.evaluate(counter_var))
- self.assertEqual(i, self.evaluate(get_next))
- self.assertEqual(10, self.evaluate(counter_var))
+ self.assertEqual(i, sess.run(counter_var))
+ self.assertEqual(i, sess.run(get_next))
+ self.assertEqual(10, sess.run(counter_var))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- self.assertEqual(10, self.evaluate(counter_var))
+ self.assertEqual(10, sess.run(counter_var))
def testMapDict(self):
iterator = (dataset_ops.Dataset.range(10)
@@ -519,9 +517,9 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for i in range(10):
- self.assertEqual(i * 2 + i**2, self.evaluate(get_next))
+ self.assertEqual(i * 2 + i**2, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -569,8 +567,8 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
- self.assertAllEqual(row**2, self.evaluate(get_next))
+ sess.run(init_op)
+ self.assertAllEqual(row**2, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -611,7 +609,7 @@
row = np.arange(6)
for num in [2, 3, 4]:
init_op, get_next = build_dataset(row, num)
- self.evaluate(init_op)
+ sess.run(init_op)
for i in range(6):
self.assertEqual(
(i // 2 if i % 2 else i * 2) if (num == 2 or num == 3) else i * 2,
@@ -652,7 +650,7 @@
row = np.arange(6)
for num in [2, 3, 4]:
init_op, get_next = build_dataset(row, num)
- self.evaluate(init_op)
+ sess.run(init_op)
self.assertAllEqual(
[x // 2 if (num == 2 or num == 3) else x * 2 for x in row],
sess.run(get_next))
@@ -697,7 +695,7 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
self.assertAllEqual([(x // 2 if x % 2 else x * 2) if
(num == 2 or num == 3) else x * 2 for x in row],
sess.run(get_next))
@@ -735,7 +733,7 @@
for buffer_size in [1, 10, 100, 1000]:
sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
for i in range(100):
- self.assertEqual(i * i, self.evaluate(get_next))
+ self.assertEqual(i * i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -753,10 +751,10 @@
sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size})
for i in range(event_will_be_set_after_consuming):
self.assertFalse(ev.is_set())
- self.assertEqual(i * i, self.evaluate(get_next))
+ self.assertEqual(i * i, sess.run(get_next))
ev.wait()
for i in range(event_will_be_set_after_consuming, 100):
- self.assertEqual(i * i, self.evaluate(get_next))
+ self.assertEqual(i * i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -768,9 +766,9 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for i in range(10):
- self.assertEqual((i, 37.0), self.evaluate(get_next))
+ self.assertEqual((i, 37.0), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -789,9 +787,9 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for i in range(10):
- self.assertEqual((i, 37.0), self.evaluate(get_next))
+ self.assertEqual((i, 37.0), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -810,9 +808,9 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for i in range(10):
- actual = self.evaluate(get_next)
+ actual = sess.run(get_next)
self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _sparse(i))
with self.assertRaises(errors.OutOfRangeError):
@@ -837,9 +835,9 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for i in range(10):
- actual = self.evaluate(get_next)
+ actual = sess.run(get_next)
self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
with self.assertRaises(errors.OutOfRangeError):
@@ -861,9 +859,9 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for i in range(100):
- self.assertEqual(i, self.evaluate(get_next))
+ self.assertEqual(i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -875,9 +873,9 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- self.evaluate(init_op)
+ sess.run(init_op)
for i in range(10):
- self.assertEqual((i, b"hello", 10), self.evaluate(get_next))
+ self.assertEqual((i, b"hello", 10), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -945,7 +943,7 @@
with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "BrokenConst"):
- self.evaluate(iterator.initializer)
+ sess.run(iterator.initializer)
# pylint: disable=g-long-lambda
@parameterized.named_parameters(
@@ -972,7 +970,7 @@
get_next = iterator.get_next()
with self.cached_session() as sess:
- tids = self.evaluate(get_next)
+ tids = sess.run(get_next)
self.assertTrue(all(tids[0] == tid for tid in tids))
# pylint: enable=g-long-lambda
@@ -996,7 +994,7 @@
expected = map_fn(*sess.run(self.structuredElement(structure)))
else:
expected = map_fn(sess.run(self.structuredElement(structure)))
- self.assertEqual(expected, self.evaluate(get_next))
+ self.assertEqual(expected, sess.run(get_next))
@parameterized.named_parameters(
("Sequential", None),
@@ -1011,7 +1009,7 @@
with self.cached_session() as sess:
sess.run(iterator.initializer, feed_dict={captured_t: 42})
- self.assertEqual(42, self.evaluate(get_next))
+ self.assertEqual(42, sess.run(get_next))
@parameterized.named_parameters(
("1", 1, 1),
@@ -1030,7 +1028,7 @@
with self.cached_session(config=config) as sess:
for i in range(num_elements):
coordination_events[i].set()
- self.assertEqual(i * i, self.evaluate(get_next))
+ self.assertEqual(i * i, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -1052,113 +1050,10 @@
for element in elements:
coordination_events[element].set()
- self.assertEqual(element * element, self.evaluate(get_next))
+ self.assertEqual(element * element, sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
-class MapDatasetBenchmark(test.Benchmark):
-
- def benchmarkChainOfMaps(self):
- chain_lengths = [0, 1, 2, 5, 10, 20, 50]
- for chain_length in chain_lengths:
- for mode in ["general", "single-threaded", "short-circuit"]:
- if mode == "general":
- map_fn = lambda x: x + 1
- use_inter_op_parallelism = True
- print_label = ""
- benchmark_label = ""
- if mode == "single-threaded":
- map_fn = lambda x: x + 1
- use_inter_op_parallelism = False
- print_label = " (single threaded mode)"
- benchmark_label = "_single_threaded"
- if mode == "short-circuit":
- map_fn = lambda x: x
- use_inter_op_parallelism = True # should not have any significance
- print_label = " (short circuit mode)"
- benchmark_label = "_short_circuit"
-
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
- for _ in range(chain_length):
- dataset = dataset_ops.MapDataset(
- dataset,
- map_fn,
- use_inter_op_parallelism=use_inter_op_parallelism)
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(5):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- print("Map dataset chain length%s: %d Median wall time: %f" %
- (print_label, chain_length, median_wall_time))
- self.report_benchmark(
- iters=1000,
- wall_time=median_wall_time,
- name="benchmark_map_dataset_chain_latency_%d%s" %
- (chain_length, benchmark_label))
-
- def benchmarkMapFanOut(self):
- fan_outs = [1, 2, 5, 10, 20, 50, 100]
- for fan_out in fan_outs:
- for mode in ["general", "single-threaded", "short-circuit"]:
- if mode == "general":
- map_fn = lambda *xs: [x + 1 for x in xs]
- use_inter_op_parallelism = True
- print_label = ""
- benchmark_label = ""
- if mode == "single-threaded":
- map_fn = lambda *xs: [x + 1 for x in xs]
- use_inter_op_parallelism = False
- print_label = " (single threaded mode)"
- benchmark_label = "_single_threaded"
- if mode == "short-circuit":
- map_fn = lambda *xs: xs
- use_inter_op_parallelism = True # should not have any significance
- print_label = " (short circuit mode)"
- benchmark_label = "_short_circuit"
-
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(
- tuple(0 for _ in range(fan_out))).repeat(None)
- dataset = dataset_ops.MapDataset(
- dataset,
- map_fn,
- use_inter_op_parallelism=use_inter_op_parallelism)
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(5):
- sess.run(next_element[0].op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element[0].op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- print("Map dataset fan out%s: %d Median wall time: %f" %
- (print_label, fan_out, median_wall_time))
- self.report_benchmark(
- iters=1000,
- wall_time=median_wall_time,
- name="benchmark_map_dataset_fan_out_%d%s" % (fan_out,
- benchmark_label))
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
index ea6828e..886c9ac 100644
--- a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
+++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""MultiDeviceIterator tests."""
+"""Tests for `tf.data.MultiDeviceIterator`."""
from __future__ import absolute_import
from __future__ import division
@@ -31,6 +31,7 @@
from tensorflow.python.platform import test
+# TODO(b/117581999): Add eager coverage.
class MultiDeviceIteratorTest(test_base.DatasetTestBase):
def testNoGetNext(self):
@@ -55,8 +56,8 @@
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
+ self.evaluate(elem_on_1)
+ self.evaluate(elem_on_2)
def testOneOnSameDevice(self):
with ops.device("/cpu:0"):
@@ -72,8 +73,8 @@
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
+ self.evaluate(elem_on_1)
+ self.evaluate(elem_on_2)
def testRepeatDevices(self):
with ops.device("/cpu:0"):
@@ -92,10 +93,10 @@
self.assertEqual(i + 2, self.evaluate(elem_on_3))
self.assertEqual(i + 3, self.evaluate(elem_on_4))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
- sess.run(elem_on_3)
- sess.run(elem_on_4)
+ self.evaluate(elem_on_1)
+ self.evaluate(elem_on_2)
+ self.evaluate(elem_on_3)
+ self.evaluate(elem_on_4)
def testNotFullyDivisible(self):
dataset = dataset_ops.Dataset.range(9)
@@ -111,8 +112,8 @@
self.assertEqual(i + 1, self.evaluate(elem_on_2))
self.assertEqual(8, self.evaluate(elem_on_1))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
+ self.evaluate(elem_on_1)
+ self.evaluate(elem_on_2)
def testGetNextAsOptional(self):
dataset = dataset_ops.Dataset.range(9)
@@ -143,9 +144,9 @@
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(elem_on_1_t)
+ self.evaluate(elem_on_1_t)
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(elem_on_2_t)
+ self.evaluate(elem_on_2_t)
def testUneven(self):
dataset = dataset_ops.Dataset.range(10)
@@ -161,8 +162,8 @@
for i in range(0, 10, 2):
self.assertEqual(i + 1, self.evaluate(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
+ self.evaluate(elem_on_1)
+ self.evaluate(elem_on_2)
def testMultipleInitializations(self):
with ops.device("/cpu:0"):
@@ -179,7 +180,8 @@
with self.test_session(config=config) as sess:
for i in range(1000):
sess.run(init_op, feed_dict={epoch: i})
- self.assertEqual([(i, 0), (i, 1)], sess.run([elem_on_1, elem_on_2]))
+ self.assertEqual([(i, 0), (i, 1)], self.evaluate([elem_on_1,
+ elem_on_2]))
def testBasicGpu(self):
if not test_util.is_gpu_available():
@@ -197,8 +199,8 @@
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
+ self.evaluate(elem_on_1)
+ self.evaluate(elem_on_2)
def testUnevenGpu(self):
if not test_util.is_gpu_available():
@@ -217,8 +219,8 @@
for i in range(0, 10, 2):
self.assertEqual(i + 1, self.evaluate(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
+ self.evaluate(elem_on_1)
+ self.evaluate(elem_on_2)
def testGetNextAsOptionalGpu(self):
if not test_util.is_gpu_available():
@@ -252,9 +254,9 @@
self.assertFalse(self.evaluate(elem_on_1_has_value_t))
self.assertFalse(self.evaluate(elem_on_2_has_value_t))
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(elem_on_1_t)
+ self.evaluate(elem_on_1_t)
with self.assertRaises(errors.InvalidArgumentError):
- sess.run(elem_on_2_t)
+ self.evaluate(elem_on_2_t)
def testOptimization(self):
dataset = dataset_ops.Dataset.range(10)
@@ -277,8 +279,8 @@
self.assertEqual(i, self.evaluate(elem_on_1))
self.assertEqual(i + 1, self.evaluate(elem_on_2))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(elem_on_1)
- sess.run(elem_on_2)
+ self.evaluate(elem_on_1)
+ self.evaluate(elem_on_2)
if __name__ == "__main__":
diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_test.py
similarity index 64%
rename from tensorflow/python/data/kernel_tests/optional_ops_test.py
rename to tensorflow/python/data/kernel_tests/optional_test.py
index 0981ff9..8640131 100644
--- a/tensorflow/python/data/kernel_tests/optional_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the Optional data type wrapper."""
+"""Tests for `tf.data.Optional`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -33,18 +33,18 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
+@test_util.run_all_in_graph_and_eager_modes
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
- @test_util.run_in_graph_and_eager_modes
def testFromValue(self):
opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
self.assertTrue(self.evaluate(opt.has_value()))
self.assertEqual(37.0, self.evaluate(opt.get_value()))
- @test_util.run_in_graph_and_eager_modes
def testFromStructuredValue(self):
opt = optional_ops.Optional.from_value({
"a": constant_op.constant(37.0),
@@ -56,7 +56,6 @@
"b": ([b"Foo"], b"Bar")
}, self.evaluate(opt.get_value()))
- @test_util.run_in_graph_and_eager_modes
def testFromSparseTensor(self):
st_0 = sparse_tensor.SparseTensorValue(
indices=np.array([[0]]),
@@ -75,7 +74,6 @@
self.assertAllEqual(expected.dense_shape,
self.evaluate(actual.dense_shape))
- @test_util.run_in_graph_and_eager_modes
def testFromNone(self):
value_structure = structure.TensorStructure(dtypes.float32, [])
opt = optional_ops.Optional.none_from_structure(value_structure)
@@ -90,7 +88,90 @@
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(opt.get_value())
- @test_util.run_in_graph_and_eager_modes
+ def testAddN(self):
+ devices = ["/cpu:0"]
+ if test_util.is_gpu_available():
+ devices.append("/gpu:0")
+ for device in devices:
+ with ops.device(device):
+ # With value
+ opt1 = optional_ops.Optional.from_value((1.0, 2.0))
+ opt2 = optional_ops.Optional.from_value((3.0, 4.0))
+
+ add_tensor = math_ops.add_n([opt1._variant_tensor,
+ opt2._variant_tensor])
+ add_opt = optional_ops._OptionalImpl(add_tensor, opt1.value_structure)
+ self.assertAllEqual(self.evaluate(add_opt.get_value()), (4.0, 6.0))
+
+ # Without value
+ opt_none1 = optional_ops.Optional.none_from_structure(
+ opt1.value_structure)
+ opt_none2 = optional_ops.Optional.none_from_structure(
+ opt2.value_structure)
+ add_tensor = math_ops.add_n([opt_none1._variant_tensor,
+ opt_none2._variant_tensor])
+ add_opt = optional_ops._OptionalImpl(add_tensor,
+ opt_none1.value_structure)
+ self.assertFalse(self.evaluate(add_opt.has_value()))
+
+ def testNestedAddN(self):
+ devices = ["/cpu:0"]
+ if test_util.is_gpu_available():
+ devices.append("/gpu:0")
+ for device in devices:
+ with ops.device(device):
+ opt1 = optional_ops.Optional.from_value([1, 2.0])
+ opt2 = optional_ops.Optional.from_value([3, 4.0])
+ opt3 = optional_ops.Optional.from_value((5.0, opt1._variant_tensor))
+ opt4 = optional_ops.Optional.from_value((6.0, opt2._variant_tensor))
+
+ add_tensor = math_ops.add_n([opt3._variant_tensor,
+ opt4._variant_tensor])
+ add_opt = optional_ops._OptionalImpl(add_tensor, opt3.value_structure)
+ self.assertEqual(self.evaluate(add_opt.get_value()[0]), 11.0)
+
+ inner_add_opt = optional_ops._OptionalImpl(add_opt.get_value()[1],
+ opt1.value_structure)
+ self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0])
+
+ def testZerosLike(self):
+ devices = ["/cpu:0"]
+ if test_util.is_gpu_available():
+ devices.append("/gpu:0")
+ for device in devices:
+ with ops.device(device):
+ # With value
+ opt = optional_ops.Optional.from_value((1.0, 2.0))
+ zeros_tensor = array_ops.zeros_like(opt._variant_tensor)
+ zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
+ opt.value_structure)
+ self.assertAllEqual(self.evaluate(zeros_opt.get_value()),
+ (0.0, 0.0))
+
+ # Without value
+ opt_none = optional_ops.Optional.none_from_structure(
+ opt.value_structure)
+ zeros_tensor = array_ops.zeros_like(opt_none._variant_tensor)
+ zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
+ opt_none.value_structure)
+ self.assertFalse(self.evaluate(zeros_opt.has_value()))
+
+ def testNestedZerosLike(self):
+ devices = ["/cpu:0"]
+ if test_util.is_gpu_available():
+ devices.append("/gpu:0")
+ for device in devices:
+ with ops.device(device):
+ opt1 = optional_ops.Optional.from_value(1.0)
+ opt2 = optional_ops.Optional.from_value(opt1._variant_tensor)
+
+ zeros_tensor = array_ops.zeros_like(opt2._variant_tensor)
+ zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
+ opt2.value_structure)
+ inner_zeros_opt = optional_ops._OptionalImpl(zeros_opt.get_value(),
+ opt1.value_structure)
+ self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0)
+
def testCopyToGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
@@ -120,6 +201,41 @@
self.evaluate(gpu_optional_with_value_values))
self.assertFalse(self.evaluate(gpu_optional_none_has_value))
+ def testNestedCopyToGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ with ops.device("/cpu:0"):
+ optional_with_value = optional_ops.Optional.from_value(
+ (constant_op.constant(37.0), constant_op.constant("Foo"),
+ constant_op.constant(42)))
+ optional_none = optional_ops.Optional.none_from_structure(
+ structure.TensorStructure(dtypes.float32, []))
+ nested_optional = optional_ops.Optional.from_value(
+ (optional_with_value._variant_tensor, optional_none._variant_tensor,
+ 1.0))
+
+ with ops.device("/gpu:0"):
+ gpu_nested_optional = optional_ops._OptionalImpl(
+ array_ops.identity(nested_optional._variant_tensor),
+ nested_optional.value_structure)
+
+ gpu_nested_optional_has_value = gpu_nested_optional.has_value()
+ gpu_nested_optional_values = gpu_nested_optional.get_value()
+
+ self.assertTrue(self.evaluate(gpu_nested_optional_has_value))
+
+ inner_with_value = optional_ops._OptionalImpl(
+ gpu_nested_optional_values[0], optional_with_value.value_structure)
+
+ inner_none = optional_ops._OptionalImpl(
+ gpu_nested_optional_values[1], optional_none.value_structure)
+
+ self.assertEqual((37.0, b"Foo", 42),
+ self.evaluate(inner_with_value.get_value()))
+ self.assertFalse(self.evaluate(inner_none.has_value()))
+ self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2]))
+
def _assertElementValueEqual(self, expected, actual):
if isinstance(expected, dict):
self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
@@ -151,7 +267,8 @@
optional_ops.OptionalStructure(
structure.TensorStructure(dtypes.float32, []))),
)
- def testOptionalStructure(self, tf_value_fn, expected_value_structure):
+ def testSkipEagerOptionalStructure(self, tf_value_fn,
+ expected_value_structure):
tf_value = tf_value_fn()
opt = optional_ops.Optional.from_value(tf_value)
@@ -205,7 +322,8 @@
indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
dense_shape=[2, 2])}, False),
)
- def testIteratorGetNextAsOptional(self, np_value, tf_value_fn, works_on_gpu):
+ def testSkipEagerIteratorGetNextAsOptional(self, np_value, tf_value_fn,
+ works_on_gpu):
if not works_on_gpu and test.is_gpu_available():
self.skipTest("Test case not yet supported on GPU.")
ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
@@ -227,7 +345,7 @@
# For each element of the dataset, assert that the optional evaluates to
# the expected value.
- self.evaluate(iterator.initializer)
+ sess.run(iterator.initializer)
for _ in range(3):
elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
self.assertTrue(elem_has_value)
@@ -236,7 +354,7 @@
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
# false, and attempting to get the value will fail.
for _ in range(2):
- self.assertFalse(self.evaluate(elem_has_value_t))
+ self.assertFalse(sess.run(elem_has_value_t))
with self.assertRaises(errors.InvalidArgumentError):
sess.run(elem_value_t)
diff --git a/tensorflow/python/data/kernel_tests/padded_batch_test.py b/tensorflow/python/data/kernel_tests/padded_batch_test.py
new file mode 100644
index 0000000..5f20d7b
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/padded_batch_test.py
@@ -0,0 +1,240 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.padded_batch()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+def _random_seq_lens(count):
+ return np.random.randint(20, size=(count,)).astype(np.int32)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('default_padding', _random_seq_lens(32), 4, [-1], False),
+ ('constant_padding', _random_seq_lens(32), 4, [25], False),
+ ('uneven_with_remainder', _random_seq_lens(34), 4, [-1], False),
+ ('uneven_without_remainder', _random_seq_lens(34), 4, [-1], True),
+ )
+ def testPaddedBatchDataset(self, seq_lens, batch_size, padded_shapes,
+ drop_remainder):
+ """Tests the padded batch dataset logic for various input configurations.
+
+ Args:
+ seq_lens: the input sequence lengths
+ batch_size: the batch size
+ padded_shapes: the padded shapes to use
+ drop_remainder: whether a smaller batch size should be produced if batch
+ size does not divide number of inputs evenly
+ """
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
+ lambda x: array_ops.fill([x], x)).padded_batch(
+ batch_size=batch_size,
+ drop_remainder=drop_remainder,
+ padded_shapes=padded_shapes)
+
+ num_full_batches = len(seq_lens) // batch_size
+ get_next = self.getNext(dataset)
+ for i in range(num_full_batches):
+ result = self.evaluate(get_next())
+ padded_len = padded_shapes[0]
+ if padded_len is None or padded_len == -1:
+ padded_len = np.max(result) if result.size > 0 else 0
+ self.assertEqual((batch_size, padded_len), result.shape)
+ for j in range(batch_size):
+ seq_len = seq_lens[(i * batch_size) + j]
+ self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
+ self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
+
+ if not drop_remainder and len(seq_lens) % batch_size > 0:
+ result = self.evaluate(get_next())
+ padded_len = np.max(result) if result.size > 0 else 0
+ self.assertEqual((len(seq_lens) % batch_size, padded_len), result.shape)
+ for j in range(len(seq_lens) % batch_size):
+ seq_len = seq_lens[num_full_batches * batch_size + j]
+ self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
+ self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ def testPaddedBatchShortPadding(self):
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(
+ [6, 5, 5, 5, 5]).map(lambda x: array_ops.fill([x], x)).padded_batch(
+ batch_size=4, padded_shapes=[5]))
+ self.assertDatasetProduces(
+ dataset, expected_error=(errors.DataLossError, ''))
+
+ def testPaddedBatchEmptyTensors(self):
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(
+ [0, 0, 0, 0]).map(lambda x: array_ops.fill([x], x)).padded_batch(
+ batch_size=4, padded_shapes=[-1]))
+ self.assertDatasetProduces(dataset, expected_output=[[[], [], [], []]])
+
+ def testPaddedBatchDatasetNonDefaultPadding(self):
+
+ def fill_tuple(x):
+ filled = array_ops.fill([x], x)
+ return (filled, string_ops.as_string(filled))
+
+ random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(random_seq_lens).map(fill_tuple)
+ .padded_batch(
+ 4, padded_shapes=([-1], [-1]), padding_values=(-1, '<end>')))
+
+ get_next = self.getNext(dataset)
+ for i in range(8):
+ result = self.evaluate(get_next())
+ padded_len = np.max(result[0])
+ self.assertEqual((4, padded_len), result[0].shape)
+ self.assertEqual((4, padded_len), result[1].shape)
+ for j in range(4):
+ seq_len = random_seq_lens[(i * 4) + j]
+ self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
+ self.assertAllEqual(result[0][j, seq_len:],
+ [-1] * (padded_len - seq_len))
+ self.assertAllEqual(result[1][j, :seq_len],
+ [compat.as_bytes(str(seq_len))] * seq_len)
+ self.assertAllEqual(result[1][j, seq_len:],
+ [b'<end>'] * (padded_len - seq_len))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ def testPaddedBatchDatasetUnicode(self):
+ # See GitHub issue 16149
+ def generator():
+ data = [[u'Простой', u'тест', u'юникода'],
+ [u'никогда', u'не', u'бывает', u'простым']]
+
+ for seq in data:
+ yield seq, [0, 1, 2, 3]
+
+ dataset = dataset_ops.Dataset.from_generator(
+ generator, (dtypes.string, dtypes.int32),
+ (tensor_shape.TensorShape([None]), tensor_shape.TensorShape([None])))
+ padded_dataset = dataset.padded_batch(
+ 2, padded_shapes=([None], [None]), padding_values=('', 0))
+ next_element = self.getNext(padded_dataset)
+ self.evaluate(next_element())
+
+ def testSkipEagerPaddedBatchDatasetShapeSpecifications(self):
+ int_placeholder = array_ops.placeholder(dtypes.int32)
+ float_placeholder = array_ops.placeholder(dtypes.float32)
+ string_placeholder = array_ops.placeholder(dtypes.string)
+ input_dataset = dataset_ops.Dataset.from_tensors(
+ (int_placeholder, float_placeholder, string_placeholder))
+
+ # Test different ways of specifying the `padded_shapes` argument.
+ dynamic_padding_from_tensor_shapes = input_dataset.padded_batch(
+ 32,
+ padded_shapes=(tensor_shape.TensorShape([None]),
+ tensor_shape.TensorShape([None, None]),
+ tensor_shape.TensorShape([37])))
+ dynamic_padding_from_lists = input_dataset.padded_batch(
+ 32, padded_shapes=([None], [None, None], [37]))
+ dynamic_padding_from_lists_with_minus_one = input_dataset.padded_batch(
+ 32, padded_shapes=([-1], [-1, -1], [37]))
+ dynamic_padding_from_tensors = input_dataset.padded_batch(
+ 32,
+ padded_shapes=(constant_op.constant([-1], dtype=dtypes.int64),
+ constant_op.constant([-1, -1], dtype=dtypes.int64),
+ constant_op.constant([37], dtype=dtypes.int64)))
+
+ for dataset in [
+ dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists,
+ dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors
+ ]:
+ self.assertEqual([None, None], dataset.output_shapes[0].as_list())
+ self.assertEqual([None, None, None], dataset.output_shapes[1].as_list())
+ self.assertEqual([None, 37], dataset.output_shapes[2].as_list())
+
+ def testPaddedBatchSparseError(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
+
+ with self.assertRaises(TypeError):
+ _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10)
+
+ def testPaddedBatchShapeError(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'The padded shape \(1,\) is not compatible with the '
+ r'corresponding input component shape \(\).'):
+ _ = dataset_ops.Dataset.range(10).padded_batch(5, padded_shapes=[1])
+
+ with self.assertRaisesRegexp(
+ ValueError, r'The padded shape \(1,\) is not compatible with the '
+ r'corresponding input component shape \(3,\).'):
+ _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
+ 5, padded_shapes=[1])
+
+ with self.assertRaisesRegexp(
+ ValueError, r'Padded shape .* must be a 1-D tensor '
+ r'of tf.int64 values, but its shape was \(2, 2\).'):
+ _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
+ 5, padded_shapes=[[1, 1], [1, 1]])
+
+ with self.assertRaisesRegexp(
+ TypeError, r'Padded shape .* must be a 1-D tensor '
+ r'of tf.int64 values, but its element type was float32.'):
+ _ = dataset_ops.Dataset.from_tensors([1, 2, 3]).padded_batch(
+ 5, padded_shapes=constant_op.constant([1., 2., 3.]))
+
+ with self.assertRaisesRegexp(
+ ValueError, r'The padded shape \(1,\) is not compatible with the '
+ r'corresponding input component shape \(\).'):
+ shape_as_tensor = constant_op.constant([1], dtype=dtypes.int64)
+ _ = dataset_ops.Dataset.range(10).padded_batch(
+ 5, padded_shapes=shape_as_tensor)
+
+ def testSkipEagerPaddedBatchShapeError(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'The padded shape \((\?|None), (\?|None)\) is not compatible with the '
+ r'corresponding input component shape \(\).'):
+ shape_as_tensor = array_ops.placeholder(dtypes.int64, shape=[2])
+ _ = dataset_ops.Dataset.range(10).padded_batch(
+ 5, padded_shapes=shape_as_tensor)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
deleted file mode 100644
index af326ec..0000000
--- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test PrefetchDataset."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class PrefetchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @parameterized.parameters((-1), (0), (5))
- def testBufferSize(self, buffer_size):
- buffer_size_t = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = dataset_ops.Dataset.range(10).prefetch(
- buffer_size=buffer_size_t).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
- for m in range(10):
- self.assertEqual(m, self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- @parameterized.parameters((-2), (-42))
- def testInvalidBufferSize(self, buffer_size):
- buffer_size_t = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = dataset_ops.Dataset.range(10).prefetch(
- buffer_size=buffer_size_t).make_initializable_iterator()
- init_op = iterator.initializer
-
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"):
- with self.cached_session() as sess:
- sess.run(init_op, feed_dict={buffer_size_t: buffer_size})
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/prefetch_test.py b/tensorflow/python/data/kernel_tests/prefetch_test.py
new file mode 100644
index 0000000..a143ba0
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/prefetch_test.py
@@ -0,0 +1,44 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.prefetch()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class PrefetchTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.parameters((-1), (0), (5))
+ def testBufferSize(self, buffer_size):
+ dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size)
+ self.assertDatasetProduces(dataset, expected_output=range(10))
+
+ @parameterized.parameters((-2), (-42))
+ def testInvalidBufferSize(self, buffer_size):
+ dataset = dataset_ops.Dataset.range(10).prefetch(buffer_size=buffer_size)
+ self.assertDatasetProduces(
+ dataset, expected_error=(errors.InvalidArgumentError, "buffer_size"))
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/range_test.py b/tensorflow/python/data/kernel_tests/range_test.py
new file mode 100644
index 0000000..3f5d25e
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/range_test.py
@@ -0,0 +1,72 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.range()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class RangeTest(test_base.DatasetTestBase):
+
+ def testStop(self):
+ dataset = dataset_ops.Dataset.range(5)
+ self.assertDatasetProduces(dataset, expected_output=range(5))
+
+ def testStartStop(self):
+ start, stop = 2, 5
+ dataset = dataset_ops.Dataset.range(start, stop)
+ self.assertDatasetProduces(dataset, expected_output=range(2, 5))
+
+ def testStartStopStep(self):
+ start, stop, step = 2, 10, 2
+ dataset = dataset_ops.Dataset.range(start, stop, step)
+ self.assertDatasetProduces(dataset, expected_output=range(2, 10, 2))
+
+ def testZeroStep(self):
+ start, stop, step = 2, 10, 0
+ dataset = dataset_ops.Dataset.range(start, stop, step)
+ self.assertDatasetProduces(
+ dataset, expected_error=(errors.InvalidArgumentError, ""))
+
+ def testNegativeStep(self):
+ start, stop, step = 2, 10, -1
+ dataset = dataset_ops.Dataset.range(start, stop, step)
+ self.assertDatasetProduces(dataset, expected_output=range(2, 10, -1))
+
+ def testStopLessThanStart(self):
+ start, stop = 10, 2
+ dataset = dataset_ops.Dataset.range(start, stop)
+ self.assertDatasetProduces(dataset, expected_output=range(10, 2))
+
+ def testStopLessThanStartWithPositiveStep(self):
+ start, stop, step = 10, 2, 2
+ dataset = dataset_ops.Dataset.range(start, stop, step)
+ self.assertDatasetProduces(dataset, expected_output=range(10, 2, 2))
+
+ def testStopLessThanStartWithNegativeStep(self):
+ start, stop, step = 10, 2, -1
+ dataset = dataset_ops.Dataset.range(start, stop, step)
+ self.assertDatasetProduces(dataset, expected_output=range(10, 2, -1))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
deleted file mode 100644
index e26381e..0000000
--- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
+++ /dev/null
@@ -1,846 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gzip
-import os
-import zlib
-
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.ops import readers
-from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.lib.io import python_io
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import io_ops
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-try:
- import psutil # pylint: disable=g-import-not-at-top
- psutil_import_succeeded = True
-except ImportError:
- psutil_import_succeeded = False
-
-
-class TextLineDatasetTest(test_base.DatasetTestBase):
-
- def _lineText(self, f, l):
- return compat.as_bytes("%d: %d" % (f, l))
-
- def _createFiles(self,
- num_files,
- num_lines,
- crlf=False,
- compression_type=None):
- filenames = []
- for i in range(num_files):
- fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
- filenames.append(fn)
- contents = []
- for j in range(num_lines):
- contents.append(self._lineText(i, j))
- # Always include a newline after the record unless it is
- # at the end of the file, in which case we include it
- if j + 1 != num_lines or i == 0:
- contents.append(b"\r\n" if crlf else b"\n")
- contents = b"".join(contents)
-
- if not compression_type:
- with open(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "GZIP":
- with gzip.GzipFile(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "ZLIB":
- contents = zlib.compress(contents)
- with open(fn, "wb") as f:
- f.write(contents)
- else:
- raise ValueError("Unsupported compression_type", compression_type)
-
- return filenames
-
- def _testTextLineDataset(self, compression_type=None):
- test_filenames = self._createFiles(
- 2, 5, crlf=True, compression_type=compression_type)
- filenames = array_ops.placeholder(dtypes.string, shape=[None])
- num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- repeat_dataset = readers.TextLineDataset(
- filenames, compression_type=compression_type).repeat(num_epochs)
- batch_dataset = repeat_dataset.batch(batch_size)
-
- iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
- init_op = iterator.make_initializer(repeat_dataset)
- init_batch_op = iterator.make_initializer(batch_dataset)
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # Basic test: read from file 0.
- sess.run(
- init_op, feed_dict={filenames: [test_filenames[0]],
- num_epochs: 1})
- for i in range(5):
- self.assertEqual(self._lineText(0, i), self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Basic test: read from file 1.
- sess.run(
- init_op, feed_dict={filenames: [test_filenames[1]],
- num_epochs: 1})
- for i in range(5):
- self.assertEqual(self._lineText(1, i), self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Basic test: read from both files.
- sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
- for j in range(2):
- for i in range(5):
- self.assertEqual(self._lineText(j, i), self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test repeated iteration through both files.
- sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 10})
- for _ in range(10):
- for j in range(2):
- for i in range(5):
- self.assertEqual(self._lineText(j, i), self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test batched and repeated iteration through both files.
- sess.run(
- init_batch_op,
- feed_dict={filenames: test_filenames,
- num_epochs: 10,
- batch_size: 5})
- for _ in range(10):
- self.assertAllEqual([self._lineText(0, i) for i in range(5)],
- sess.run(get_next))
- self.assertAllEqual([self._lineText(1, i) for i in range(5)],
- sess.run(get_next))
-
- def testTextLineDatasetNoCompression(self):
- self._testTextLineDataset()
-
- def testTextLineDatasetGzipCompression(self):
- self._testTextLineDataset(compression_type="GZIP")
-
- def testTextLineDatasetZlibCompression(self):
- self._testTextLineDataset(compression_type="ZLIB")
-
- def testTextLineDatasetBuffering(self):
- test_filenames = self._createFiles(2, 5, crlf=True)
-
- repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10)
- iterator = repeat_dataset.make_one_shot_iterator()
-
- with self.cached_session() as sess:
- for j in range(2):
- for i in range(5):
- self.assertEqual(self._lineText(j, i), sess.run(iterator.get_next()))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(iterator.get_next())
-
- def testIteratorResourceCleanup(self):
- filename = os.path.join(self.get_temp_dir(), "text.txt")
- with open(filename, "wt") as f:
- for i in range(3):
- f.write("%d\n" % (i,))
- with context.eager_mode():
- first_iterator = iter(readers.TextLineDataset(filename))
- self.assertEqual(b"0", next(first_iterator).numpy())
- second_iterator = iter(readers.TextLineDataset(filename))
- self.assertEqual(b"0", next(second_iterator).numpy())
- # Eager kernel caching is based on op attributes, which includes the
- # Dataset's output shape. Create a different kernel to test that they
- # don't create resources with the same names.
- different_kernel_iterator = iter(
- readers.TextLineDataset(filename).repeat().batch(16))
- self.assertEqual([16], next(different_kernel_iterator).shape)
- # Remove our references to the Python Iterator objects, which (assuming no
- # reference cycles) is enough to trigger DestroyResourceOp and close the
- # partially-read files.
- del first_iterator
- del second_iterator
- del different_kernel_iterator
- if not psutil_import_succeeded:
- self.skipTest(
- "psutil is required to check that we've closed our files.")
- open_files = psutil.Process().open_files()
- self.assertNotIn(filename, [open_file.path for open_file in open_files])
-
-
-class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
-
- def setUp(self):
- super(FixedLengthRecordReaderTest, self).setUp()
- self._num_files = 2
- self._num_records = 7
- self._header_bytes = 5
- self._record_bytes = 3
- self._footer_bytes = 2
-
- def _record(self, f, r):
- return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
-
- def _createFiles(self, compression_type=None):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
- filenames.append(fn)
-
- contents = []
- contents.append(b"H" * self._header_bytes)
- for j in range(self._num_records):
- contents.append(self._record(i, j))
- contents.append(b"F" * self._footer_bytes)
- contents = b"".join(contents)
-
- if not compression_type:
- with open(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "GZIP":
- with gzip.GzipFile(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "ZLIB":
- contents = zlib.compress(contents)
- with open(fn, "wb") as f:
- f.write(contents)
- else:
- raise ValueError("Unsupported compression_type", compression_type)
-
- return filenames
-
- def _testFixedLengthRecordDataset(self, compression_type=None):
- test_filenames = self._createFiles(compression_type=compression_type)
- filenames = array_ops.placeholder(dtypes.string, shape=[None])
- num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- repeat_dataset = (
- readers.FixedLengthRecordDataset(
- filenames,
- self._record_bytes,
- self._header_bytes,
- self._footer_bytes,
- compression_type=compression_type).repeat(num_epochs))
- batch_dataset = repeat_dataset.batch(batch_size)
-
- iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
- init_op = iterator.make_initializer(repeat_dataset)
- init_batch_op = iterator.make_initializer(batch_dataset)
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # Basic test: read from file 0.
- sess.run(
- init_op, feed_dict={filenames: [test_filenames[0]],
- num_epochs: 1})
- for i in range(self._num_records):
- self.assertEqual(self._record(0, i), self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Basic test: read from file 1.
- sess.run(
- init_op, feed_dict={filenames: [test_filenames[1]],
- num_epochs: 1})
- for i in range(self._num_records):
- self.assertEqual(self._record(1, i), self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Basic test: read from both files.
- sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1})
- for j in range(self._num_files):
- for i in range(self._num_records):
- self.assertEqual(self._record(j, i), self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test repeated iteration through both files.
- sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 10})
- for _ in range(10):
- for j in range(self._num_files):
- for i in range(self._num_records):
- self.assertEqual(self._record(j, i), self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test batched and repeated iteration through both files.
- sess.run(
- init_batch_op,
- feed_dict={
- filenames: test_filenames,
- num_epochs: 10,
- batch_size: self._num_records
- })
- for _ in range(10):
- for j in range(self._num_files):
- self.assertAllEqual(
- [self._record(j, i) for i in range(self._num_records)],
- sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testFixedLengthRecordDatasetNoCompression(self):
- self._testFixedLengthRecordDataset()
-
- def testFixedLengthRecordDatasetGzipCompression(self):
- self._testFixedLengthRecordDataset(compression_type="GZIP")
-
- def testFixedLengthRecordDatasetZlibCompression(self):
- self._testFixedLengthRecordDataset(compression_type="ZLIB")
-
- def testFixedLengthRecordDatasetBuffering(self):
- test_filenames = self._createFiles()
- dataset = readers.FixedLengthRecordDataset(
- test_filenames,
- self._record_bytes,
- self._header_bytes,
- self._footer_bytes,
- buffer_size=10)
- iterator = dataset.make_one_shot_iterator()
-
- with self.cached_session() as sess:
- for j in range(self._num_files):
- for i in range(self._num_records):
- self.assertEqual(self._record(j, i), sess.run(iterator.get_next()))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(iterator.get_next())
-
- def testFixedLengthRecordDatasetWrongSize(self):
- test_filenames = self._createFiles()
- dataset = readers.FixedLengthRecordDataset(
- test_filenames,
- self._record_bytes + 1, # Incorrect record length.
- self._header_bytes,
- self._footer_bytes,
- buffer_size=10)
- iterator = dataset.make_one_shot_iterator()
-
- with self.cached_session() as sess:
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- r"Excluding the header \(5 bytes\) and footer \(2 bytes\), input "
- r"file \".*fixed_length_record.0.txt\" has body length 21 bytes, "
- r"which is not an exact multiple of the record length \(4 bytes\)."):
- sess.run(iterator.get_next())
-
- def _iterator_checkpoint_path(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def _save_op(self, iterator_resource):
- iterator_state_variant = gen_dataset_ops.serialize_iterator(
- iterator_resource)
- save_op = io_ops.write_file(
- self._iterator_checkpoint_path(),
- parsing_ops.serialize_tensor(iterator_state_variant))
- return save_op
-
- def _restore_op(self, iterator_resource):
- iterator_state_variant = parsing_ops.parse_tensor(
- io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
- restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
- iterator_state_variant)
- return restore_op
-
- def _build_iterator_graph(self, num_epochs):
- filenames = self._createFiles()
- dataset = (readers.FixedLengthRecordDataset(
- filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
- .repeat(num_epochs))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next_op = iterator.get_next()
- save_op = self._save_op(iterator._iterator_resource)
- restore_op = self._restore_op(iterator._iterator_resource)
- return init_op, get_next_op, save_op, restore_op
-
- def _restore_iterator(self):
- output_types = dtypes.string
- output_shapes = tensor_shape.scalar()
- iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
- get_next = iterator.get_next()
- restore_op = self._restore_op(iterator._iterator_resource)
- return restore_op, get_next
-
- def testSaveRestore(self):
- num_epochs = 10
- epoch_break = 5
- file_break = self._num_files // 2
- record_break = self._num_records // 2
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.session(graph=g) as sess:
- self.evaluate(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- self.evaluate(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch == epoch_break and f == file_break and
- r == record_break):
- self.evaluate(save_op)
- break
- self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
- else:
- continue
- break
- else:
- continue
- break
- else:
- with self.assertRaises(errors.OutOfRangeError):
- self.evaluate(get_next_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.session(graph=g) as sess:
- self.evaluate(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch < epoch_break or
- (epoch == epoch_break and f < file_break) or
- (epoch == epoch_break and f == file_break and
- r < record_break)):
- continue
- self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- self.evaluate(get_next_op)
-
- def testInitThenRestore(self):
- # Note: Calling init_op before restore_op is redundant. This test just makes
- # sure we do not fail if restore is called on an already initialized
- # iterator resource.
- num_epochs = 10
- epoch_break = 5
- file_break = self._num_files // 2
- record_break = self._num_records // 2
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.session(graph=g) as sess:
- self.evaluate(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- self.evaluate(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch == epoch_break and f == file_break and
- r == record_break):
- self.evaluate(save_op)
- break
- self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
- else:
- continue
- break
- else:
- continue
- break
- else:
- with self.assertRaises(errors.OutOfRangeError):
- self.evaluate(get_next_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.session(graph=g) as sess:
- self.evaluate(init_op)
- self.evaluate(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch < epoch_break or
- (epoch == epoch_break and f < file_break) or
- (epoch == epoch_break and f == file_break and
- r < record_break)):
- continue
- self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- self.evaluate(get_next_op)
-
- def testRestoreInModifiedGraph(self):
- num_epochs = 10
- num_epochs_1 = 20
- epoch_break = 5
- file_break = self._num_files // 2
- record_break = self._num_records // 2
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.session(graph=g) as sess:
- self.evaluate(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- self.evaluate(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch == epoch_break and f == file_break and
- r == record_break):
- self.evaluate(save_op)
- break
- self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
- else:
- continue
- break
- else:
- continue
- break
- else:
- with self.assertRaises(errors.OutOfRangeError):
- self.evaluate(get_next_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs_1)
- with self.session(graph=g) as sess:
- self.evaluate(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch < epoch_break or
- (epoch == epoch_break and f < file_break) or
- (epoch == epoch_break and f == file_break and
- r < record_break)):
- continue
- self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- self.evaluate(get_next_op)
-
- def testRestoreWithoutBuildingDatasetGraph(self):
- num_epochs = 10
- epoch_break = 5
- file_break = self._num_files // 2
- record_break = self._num_records // 2
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.session(graph=g) as sess:
- self.evaluate(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- self.evaluate(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch == epoch_break and f == file_break and
- r == record_break):
- self.evaluate(save_op)
- break
- self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
- else:
- continue
- break
- else:
- continue
- break
- else:
- with self.assertRaises(errors.OutOfRangeError):
- self.evaluate(get_next_op)
-
- with ops.Graph().as_default() as g:
- restore_op, get_next_op = self._restore_iterator()
- with self.session(graph=g) as sess:
- self.evaluate(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch < epoch_break or
- (epoch == epoch_break and f < file_break) or
- (epoch == epoch_break and f == file_break and
- r < record_break)):
- continue
- self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- self.evaluate(get_next_op)
-
- def testRestoreUnusedIterator(self):
- num_epochs = 10
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.session(graph=g) as sess:
- self.evaluate(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- self.evaluate(restore_op)
- # Save unused iterator.
- self.evaluate(save_op)
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.session(graph=g) as sess:
- self.evaluate(restore_op)
- for _ in range(num_epochs * self._num_files * self._num_records):
- self.evaluate(get_next_op)
- with self.assertRaises(errors.OutOfRangeError):
- self.evaluate(get_next_op)
-
- def testRestoreExhaustedIterator(self):
- num_epochs = 10
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.session(graph=g) as sess:
- self.evaluate(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- self.evaluate(restore_op)
- for _ in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- self.assertEqual(self._record(f, r), self.evaluate(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- self.evaluate(get_next_op)
- self.evaluate(save_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.session(graph=g) as sess:
- self.evaluate(restore_op)
- with self.assertRaises(errors.OutOfRangeError):
- self.evaluate(get_next_op)
-
-
-class TFRecordDatasetTest(test_base.DatasetTestBase):
-
- def setUp(self):
- super(TFRecordDatasetTest, self).setUp()
- self._num_files = 2
- self._num_records = 7
-
- self.test_filenames = self._createFiles()
-
- self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
- self.num_epochs = array_ops.placeholder_with_default(
- constant_op.constant(1, dtypes.int64), shape=[])
- self.compression_type = array_ops.placeholder_with_default("", shape=[])
- self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- repeat_dataset = readers.TFRecordDataset(self.filenames,
- self.compression_type).repeat(
- self.num_epochs)
- batch_dataset = repeat_dataset.batch(self.batch_size)
-
- iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
- self.init_op = iterator.make_initializer(repeat_dataset)
- self.init_batch_op = iterator.make_initializer(batch_dataset)
- self.get_next = iterator.get_next()
-
- def _record(self, f, r):
- return compat.as_bytes("Record %d of file %d" % (r, f))
-
- def _createFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
- filenames.append(fn)
- writer = python_io.TFRecordWriter(fn)
- for j in range(self._num_records):
- writer.write(self._record(i, j))
- writer.close()
- return filenames
-
- def testReadOneEpoch(self):
- with self.cached_session() as sess:
- # Basic test: read from file 0.
- sess.run(
- self.init_op,
- feed_dict={
- self.filenames: [self.test_filenames[0]],
- self.num_epochs: 1
- })
- for i in range(self._num_records):
- self.assertAllEqual(self._record(0, i), sess.run(self.get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.get_next)
-
- # Basic test: read from file 1.
- sess.run(
- self.init_op,
- feed_dict={
- self.filenames: [self.test_filenames[1]],
- self.num_epochs: 1
- })
- for i in range(self._num_records):
- self.assertAllEqual(self._record(1, i), sess.run(self.get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.get_next)
-
- # Basic test: read from both files.
- sess.run(
- self.init_op,
- feed_dict={self.filenames: self.test_filenames,
- self.num_epochs: 1})
- for j in range(self._num_files):
- for i in range(self._num_records):
- self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.get_next)
-
- def testReadTenEpochs(self):
- with self.cached_session() as sess:
- sess.run(
- self.init_op,
- feed_dict={self.filenames: self.test_filenames,
- self.num_epochs: 10})
- for _ in range(10):
- for j in range(self._num_files):
- for i in range(self._num_records):
- self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.get_next)
-
- def testReadTenEpochsOfBatches(self):
- with self.cached_session() as sess:
- sess.run(
- self.init_batch_op,
- feed_dict={
- self.filenames: self.test_filenames,
- self.num_epochs: 10,
- self.batch_size: self._num_records
- })
- for _ in range(10):
- for j in range(self._num_files):
- values = sess.run(self.get_next)
- self.assertAllEqual(
- [self._record(j, i) for i in range(self._num_records)], values)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.get_next)
-
- def testReadZlibFiles(self):
- zlib_files = []
- for i, fn in enumerate(self.test_filenames):
- with open(fn, "rb") as f:
- cdata = zlib.compress(f.read())
-
- zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
- with open(zfn, "wb") as f:
- f.write(cdata)
- zlib_files.append(zfn)
-
- with self.cached_session() as sess:
- sess.run(
- self.init_op,
- feed_dict={self.filenames: zlib_files,
- self.compression_type: "ZLIB"})
- for j in range(self._num_files):
- for i in range(self._num_records):
- self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.get_next)
-
- def testReadGzipFiles(self):
- gzip_files = []
- for i, fn in enumerate(self.test_filenames):
- with open(fn, "rb") as f:
- gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
- with gzip.GzipFile(gzfn, "wb") as gzf:
- gzf.write(f.read())
- gzip_files.append(gzfn)
-
- with self.cached_session() as sess:
- sess.run(
- self.init_op,
- feed_dict={self.filenames: gzip_files,
- self.compression_type: "GZIP"})
- for j in range(self._num_files):
- for i in range(self._num_records):
- self.assertAllEqual(self._record(j, i), sess.run(self.get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.get_next)
-
- def testReadWithBuffer(self):
- one_mebibyte = 2**20
- d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte)
- iterator = d.make_one_shot_iterator()
- next_element = iterator.get_next()
- with self.cached_session() as sess:
- for j in range(self._num_files):
- for i in range(self._num_records):
- self.assertAllEqual(self._record(j, i), self.evaluate(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testReadFromDatasetOfFiles(self):
- files = dataset_ops.Dataset.from_tensor_slices(self.test_filenames)
- d = readers.TFRecordDataset(files)
- iterator = d.make_one_shot_iterator()
- next_element = iterator.get_next()
- with self.cached_session() as sess:
- for j in range(self._num_files):
- for i in range(self._num_records):
- self.assertAllEqual(self._record(j, i), self.evaluate(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testReadTenEpochsFromDatasetOfFilesInParallel(self):
- files = dataset_ops.Dataset.from_tensor_slices(
- self.test_filenames).repeat(10)
- d = readers.TFRecordDataset(files, num_parallel_reads=4)
- iterator = d.make_one_shot_iterator()
- next_element = iterator.get_next()
- expected = []
- actual = []
- with self.cached_session() as sess:
- for _ in range(10):
- for j in range(self._num_files):
- for i in range(self._num_records):
- expected.append(self._record(j, i))
- actual.append(sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self.assertEqual(sorted(expected), sorted(actual))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py b/tensorflow/python/data/kernel_tests/reduce_test.py
similarity index 73%
rename from tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py
rename to tensorflow/python/data/kernel_tests/reduce_test.py
index d7f3988..d7b6539 100644
--- a/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/reduce_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.Dataset.reduce()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -22,21 +22,24 @@
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
+@test_util.run_all_in_graph_and_eager_modes
+class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
def testSum(self):
for i in range(10):
ds = dataset_ops.Dataset.range(1, i + 1)
- result = ds.reduce(np.int64(0), lambda x, y: x + y)
- with self.cached_session() as sess:
- self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
+ result = ds.reduce(
+ constant_op.constant(0, dtype=dtypes.int64), lambda x, y: x + y)
+ self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
def testSumTuple(self):
@@ -47,9 +50,8 @@
for i in range(10):
ds = dataset_ops.Dataset.range(1, i + 1)
ds = dataset_ops.Dataset.zip((ds, ds))
- result = ds.reduce(np.int64(0), reduce_fn)
- with self.cached_session() as sess:
- self.assertEqual(((i + 1) * i), self.evaluate(result))
+ result = ds.reduce(constant_op.constant(0, dtype=dtypes.int64), reduce_fn)
+ self.assertEqual(((i + 1) * i), self.evaluate(result))
def testSumAndCount(self):
@@ -59,13 +61,14 @@
for i in range(10):
ds = dataset_ops.Dataset.range(1, i + 1)
- result = ds.reduce((np.int64(0), np.int64(0)), reduce_fn)
- with self.cached_session() as sess:
- s, c = self.evaluate(result)
- self.assertEqual(((i + 1) * i) // 2, s)
- self.assertEqual(i, c)
+ result = ds.reduce((constant_op.constant(0, dtype=dtypes.int64),
+ constant_op.constant(0, dtype=dtypes.int64)),
+ reduce_fn)
+ s, c = self.evaluate(result)
+ self.assertEqual(((i + 1) * i) // 2, s)
+ self.assertEqual(i, c)
- def testSquareUsingPlaceholder(self):
+ def testSkipEagerSquareUsingPlaceholder(self):
delta = array_ops.placeholder(dtype=dtypes.int64)
def reduce_fn(state, _):
@@ -92,9 +95,7 @@
for i in range(10):
ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1))
result = ds.reduce(make_sparse_fn(0), reduce_fn)
- with self.cached_session() as sess:
- self.assertSparseValuesEqual(
- make_sparse_fn(i + 1), self.evaluate(result))
+ self.assertSparseValuesEqual(make_sparse_fn(i + 1), self.evaluate(result))
def testNested(self):
@@ -116,10 +117,10 @@
for i in range(10):
ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn)
result = ds.reduce(map_fn(0), reduce_fn)
- with self.cached_session() as sess:
- result = self.evaluate(result)
- self.assertEqual(((i + 1) * i) // 2, result["dense"])
- self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
+ result = self.evaluate(result)
+ self.assertEqual(((i + 1) * i) // 2, result["dense"])
+ self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/kernel_tests/repeat_test.py b/tensorflow/python/data/kernel_tests/repeat_test.py
new file mode 100644
index 0000000..4ef2fc1
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/repeat_test.py
@@ -0,0 +1,84 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.repeat()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class RepeatTest(test_base.DatasetTestBase):
+
+ def testRepeatTensorDataset(self):
+ """Test a dataset that repeats its input multiple times."""
+ components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
+ # This placeholder can be fed when dataset-definition subgraph
+ # runs (i.e. `init_op` below) to configure the number of
+ # repetitions used in a particular iterator.
+
+ def do_test(count):
+ dataset = dataset_ops.Dataset.from_tensors(components).repeat(count)
+ self.assertEqual([c.shape for c in components],
+ [shape for shape in dataset.output_shapes])
+ self.assertDatasetProduces(dataset, [components] * count)
+
+ # Test a finite repetition.
+ do_test(3)
+
+ # test a different finite repetition.
+ do_test(7)
+
+ # Test an empty repetition.
+ do_test(0)
+
+ # Test an infinite repetition.
+ # NOTE(mrry): There's not a good way to test that the sequence
+ # actually is infinite.
+ dataset = dataset_ops.Dataset.from_tensors(components).repeat(-1)
+ self.assertEqual([c.shape for c in components],
+ [shape for shape in dataset.output_shapes])
+ get_next = self.getNext(dataset)
+ for _ in range(17):
+ results = self.evaluate(get_next())
+ for component, result_component in zip(components, results):
+ self.assertAllEqual(component, result_component)
+
+ def testRepeatRepeatTensorDataset(self):
+ """Test the composition of repeat datasets."""
+ components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
+ inner_count, outer_count = 7, 14
+
+ dataset = dataset_ops.Dataset.from_tensors(components).repeat(
+ inner_count).repeat(outer_count)
+ self.assertEqual([c.shape for c in components],
+ [shape for shape in dataset.output_shapes])
+ self.assertDatasetProduces(dataset,
+ [components] * (inner_count * outer_count))
+
+ def testRepeatEmptyDataset(self):
+ """Test that repeating an empty dataset does not hang."""
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10).repeat(-1)
+ self.assertDatasetProduces(dataset, [])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
deleted file mode 100644
index 946aa01..0000000
--- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
+++ /dev/null
@@ -1,210 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class SequenceDatasetTest(test_base.DatasetTestBase):
-
- def testRepeatTensorDataset(self):
- """Test a dataset that repeats its input multiple times."""
- components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
- # This placeholder can be fed when dataset-definition subgraph
- # runs (i.e. `init_op` below) to configure the number of
- # repetitions used in a particular iterator.
- count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
-
- iterator = (dataset_ops.Dataset.from_tensors(components)
- .repeat(count_placeholder).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([c.shape for c in components],
- [t.shape for t in get_next])
-
- with self.cached_session() as sess:
- # Test a finite repetition.
- sess.run(init_op, feed_dict={count_placeholder: 3})
- for _ in range(3):
- results = self.evaluate(get_next)
- for component, result_component in zip(components, results):
- self.assertAllEqual(component, result_component)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test a different finite repetition.
- sess.run(init_op, feed_dict={count_placeholder: 7})
- for _ in range(7):
- results = self.evaluate(get_next)
- for component, result_component in zip(components, results):
- self.assertAllEqual(component, result_component)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test an empty repetition.
- sess.run(init_op, feed_dict={count_placeholder: 0})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test an infinite repetition.
- # NOTE(mrry): There's not a good way to test that the sequence
- # actually is infinite.
- sess.run(init_op, feed_dict={count_placeholder: -1})
- for _ in range(17):
- results = self.evaluate(get_next)
- for component, result_component in zip(components, results):
- self.assertAllEqual(component, result_component)
-
- def testTakeTensorDataset(self):
- components = (np.arange(10),)
- count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
-
- iterator = (dataset_ops.Dataset.from_tensor_slices(components)
- .take(count_placeholder).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([c.shape[1:] for c in components],
- [t.shape for t in get_next])
-
- with self.cached_session() as sess:
- # Take fewer than input size
- sess.run(init_op, feed_dict={count_placeholder: 4})
- for i in range(4):
- results = self.evaluate(get_next)
- self.assertAllEqual(results, components[0][i:i+1])
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Take more than input size
- sess.run(init_op, feed_dict={count_placeholder: 25})
- for i in range(10):
- results = self.evaluate(get_next)
- self.assertAllEqual(results, components[0][i:i+1])
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Take all of input
- sess.run(init_op, feed_dict={count_placeholder: -1})
- for i in range(10):
- results = self.evaluate(get_next)
- self.assertAllEqual(results, components[0][i:i+1])
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Take nothing
- sess.run(init_op, feed_dict={count_placeholder: 0})
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testSkipTensorDataset(self):
- components = (np.arange(10),)
- count_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
-
- iterator = (dataset_ops.Dataset.from_tensor_slices(components)
- .skip(count_placeholder).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([c.shape[1:] for c in components],
- [t.shape for t in get_next])
-
- with self.cached_session() as sess:
- # Skip fewer than input size, we should skip
- # the first 4 elements and then read the rest.
- sess.run(init_op, feed_dict={count_placeholder: 4})
- for i in range(4, 10):
- results = self.evaluate(get_next)
- self.assertAllEqual(results, components[0][i:i+1])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Skip more than input size: get nothing.
- sess.run(init_op, feed_dict={count_placeholder: 25})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Skip exactly input size.
- sess.run(init_op, feed_dict={count_placeholder: 10})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Set -1 for 'count': skip the entire dataset.
- sess.run(init_op, feed_dict={count_placeholder: -1})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Skip nothing
- sess.run(init_op, feed_dict={count_placeholder: 0})
- for i in range(0, 10):
- results = self.evaluate(get_next)
- self.assertAllEqual(results, components[0][i:i+1])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testRepeatRepeatTensorDataset(self):
- """Test the composition of repeat datasets."""
- components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
- inner_count = array_ops.placeholder(dtypes.int64, shape=[])
- outer_count = array_ops.placeholder(dtypes.int64, shape=[])
-
- iterator = (dataset_ops.Dataset.from_tensors(components).repeat(inner_count)
- .repeat(outer_count).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([c.shape for c in components],
- [t.shape for t in get_next])
-
- with self.cached_session() as sess:
- sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14})
- for _ in range(7 * 14):
- results = self.evaluate(get_next)
- for component, result_component in zip(components, results):
- self.assertAllEqual(component, result_component)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testRepeatEmptyDataset(self):
- """Test that repeating an empty dataset does not hang."""
- iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10)
- .repeat(-1).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
deleted file mode 100644
index b9f3c79..0000000
--- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class ShardDatasetOpTest(test_base.DatasetTestBase):
-
- def testSimpleCase(self):
- dataset = dataset_ops.Dataset.range(10).shard(5, 2)
- iterator = dataset.make_one_shot_iterator()
-
- with self.cached_session() as sess:
- self.assertEqual(2, sess.run(iterator.get_next()))
- self.assertEqual(7, sess.run(iterator.get_next()))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(iterator.get_next())
-
- def testNestedData(self):
- dataset_a = dataset_ops.Dataset.range(10)
- dataset_b = dataset_ops.Dataset.range(10, 0, -1)
- dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2)
- iterator = dataset.make_one_shot_iterator()
-
- with self.cached_session() as sess:
- self.assertEqual((2, 8), sess.run(iterator.get_next()))
- self.assertEqual((7, 3), sess.run(iterator.get_next()))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(iterator.get_next())
-
- def testOffsetZero(self):
- dataset = dataset_ops.Dataset.range(10).shard(5, 0)
- iterator = dataset.make_one_shot_iterator()
-
- with self.cached_session() as sess:
- self.assertEqual(0, sess.run(iterator.get_next()))
- self.assertEqual(5, sess.run(iterator.get_next()))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(iterator.get_next())
-
- def testOffsetGreaterNumShards(self):
- with self.assertRaises(ValueError):
- dataset_ops.Dataset.range(10).shard(5, 7)
-
- def testNegativeOffset(self):
- with self.assertRaises(ValueError):
- dataset_ops.Dataset.range(10).shard(5, -3)
-
- def testNegativeNumShards(self):
- with self.assertRaises(ValueError):
- dataset_ops.Dataset.range(10).shard(-3, 1)
-
- def testZeroNumShards(self):
- with self.assertRaises(ValueError):
- dataset_ops.Dataset.range(10).shard(0, 1)
-
- def testIteratorEndsBeforeFirstElem(self):
- dataset = dataset_ops.Dataset.range(1).shard(5, 2)
- iterator = dataset.make_one_shot_iterator()
-
- with self.cached_session() as sess:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(iterator.get_next())
-
- def testLargerWorkerPool(self):
- dataset = dataset_ops.Dataset.range(10).shard(7, 5)
- iterator = dataset.make_one_shot_iterator()
- with self.cached_session() as sess:
- self.assertEqual(5, sess.run(iterator.get_next()))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(iterator.get_next())
-
- def testIndexEqualsNumShards(self):
- dataset = dataset_ops.Dataset.range(10).shard(5, 4)
- iterator = dataset.make_one_shot_iterator()
- with self.cached_session() as sess:
- self.assertEqual(4, sess.run(iterator.get_next()))
- self.assertEqual(9, sess.run(iterator.get_next()))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(iterator.get_next())
-
- def testIndexEqualsNumShards2(self):
- dataset = dataset_ops.Dataset.range(10).shard(4, 3)
- iterator = dataset.make_one_shot_iterator()
- with self.cached_session() as sess:
- self.assertEqual(3, sess.run(iterator.get_next()))
- self.assertEqual(7, sess.run(iterator.get_next()))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(iterator.get_next())
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/shard_test.py b/tensorflow/python/data/kernel_tests/shard_test.py
new file mode 100644
index 0000000..9285506
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/shard_test.py
@@ -0,0 +1,76 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.shard()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class ShardTest(test_base.DatasetTestBase):
+
+ def testSimpleCase(self):
+ dataset = dataset_ops.Dataset.range(10).shard(5, 2)
+ self.assertDatasetProduces(dataset, expected_output=[2, 7])
+
+ def testNestedData(self):
+ dataset_a = dataset_ops.Dataset.range(10)
+ dataset_b = dataset_ops.Dataset.range(10, 0, -1)
+ dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2)
+ self.assertDatasetProduces(dataset, expected_output=[(2, 8), (7, 3)])
+
+ def testOffsetZero(self):
+ dataset = dataset_ops.Dataset.range(10).shard(5, 0)
+ self.assertDatasetProduces(dataset, expected_output=[0, 5])
+
+ def testOffsetGreaterNumShards(self):
+ with self.assertRaises(ValueError):
+ dataset_ops.Dataset.range(10).shard(5, 7)
+
+ def testNegativeOffset(self):
+ with self.assertRaises(ValueError):
+ dataset_ops.Dataset.range(10).shard(5, -3)
+
+ def testNegativeNumShards(self):
+ with self.assertRaises(ValueError):
+ dataset_ops.Dataset.range(10).shard(-3, 1)
+
+ def testZeroNumShards(self):
+ with self.assertRaises(ValueError):
+ dataset_ops.Dataset.range(10).shard(0, 1)
+
+ def testIteratorEndsBeforeFirstElem(self):
+ dataset = dataset_ops.Dataset.range(1).shard(5, 2)
+ self.assertDatasetProduces(dataset, expected_output=[])
+
+ def testLargerWorkerPool(self):
+ dataset = dataset_ops.Dataset.range(10).shard(7, 5)
+ self.assertDatasetProduces(dataset, expected_output=[5])
+
+ def testIndexEqualsNumShards(self):
+ dataset = dataset_ops.Dataset.range(10).shard(5, 4)
+ self.assertDatasetProduces(dataset, expected_output=[4, 9])
+
+ def testIndexEqualsNumShards2(self):
+ dataset = dataset_ops.Dataset.range(10).shard(4, 3)
+ self.assertDatasetProduces(dataset, expected_output=[3, 7])
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
deleted file mode 100644
index 990f4f2..0000000
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ /dev/null
@@ -1,278 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import collections
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class ShuffleDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- def testShuffleDataset(self):
- components = (
- np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
- np.array([9.0, 10.0, 11.0, 12.0])
- )
- count_placeholder = array_ops.placeholder_with_default(
- constant_op.constant(5, dtypes.int64), shape=[])
- buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
-
- repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components)
- .repeat(count_placeholder))
-
- shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder,
- seed_placeholder)
-
- self.assertEqual(tuple([c.shape[1:] for c in components]),
- shuffle_dataset.output_shapes)
-
- # Create initialization ops for iterators without and with
- # shuffling, respectively.
- iterator = iterator_ops.Iterator.from_structure(
- shuffle_dataset.output_types, shuffle_dataset.output_shapes)
- init_fifo_op = iterator.make_initializer(repeat_dataset)
- init_shuffle_op = iterator.make_initializer(shuffle_dataset)
-
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # First run without shuffling to collect the "ground truth".
- self.evaluate(init_fifo_op)
- unshuffled_elements = []
- for _ in range(20):
- unshuffled_elements.append(sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Assert that the shuffled dataset has the same elements as the
- # "ground truth".
- sess.run(
- init_shuffle_op,
- feed_dict={buffer_size_placeholder: 100,
- seed_placeholder: 37})
- shuffled_elements = []
- for _ in range(20):
- shuffled_elements.append(sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- self.assertAllEqual(
- sorted(unshuffled_elements), sorted(shuffled_elements))
-
- # Assert that shuffling twice with the same seeds gives the same sequence.
- sess.run(
- init_shuffle_op,
- feed_dict={buffer_size_placeholder: 100,
- seed_placeholder: 37})
- reshuffled_elements_same_seed = []
- for _ in range(20):
- reshuffled_elements_same_seed.append(sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- self.assertEqual(shuffled_elements, reshuffled_elements_same_seed)
-
- # Assert that shuffling twice with a different seed gives a different
- # permutation of the same elements.
- sess.run(
- init_shuffle_op,
- feed_dict={buffer_size_placeholder: 100,
- seed_placeholder: 1037})
- reshuffled_elements_different_seed = []
- for _ in range(20):
- reshuffled_elements_different_seed.append(sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed)
- self.assertAllEqual(
- sorted(shuffled_elements), sorted(reshuffled_elements_different_seed))
-
- # Assert that the shuffled dataset has the same elements as the
- # "ground truth" when the buffer size is smaller than the input
- # dataset.
- sess.run(
- init_shuffle_op,
- feed_dict={buffer_size_placeholder: 2,
- seed_placeholder: 37})
- reshuffled_elements_small_buffer = []
- for _ in range(20):
- reshuffled_elements_small_buffer.append(sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- self.assertAllEqual(
- sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer))
-
- # Test the case of shuffling an empty dataset.
- sess.run(init_shuffle_op, feed_dict={buffer_size_placeholder: 2,
- seed_placeholder: 37,
- count_placeholder: 0})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testSeedZero(self):
- """Test for same behavior when the seed is a Python or Tensor zero."""
- iterator = (
- dataset_ops.Dataset.range(10).shuffle(10, seed=0)
- .make_one_shot_iterator())
- get_next = iterator.get_next()
-
- elems = []
- with self.cached_session() as sess:
- for _ in range(10):
- elems.append(sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (
- dataset_ops.Dataset.range(10).shuffle(10, seed=seed_placeholder)
- .make_initializable_iterator())
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
- for elem in elems:
- self.assertEqual(elem, self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testDefaultArguments(self):
- components = [0, 1, 2, 3, 4]
- iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5)
- .repeat().make_one_shot_iterator())
-
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- counts = collections.defaultdict(lambda: 0)
- for _ in range(10):
- for _ in range(5):
- counts[sess.run(get_next)] += 1
-
- for i in range(5):
- self.assertEqual(10, counts[i])
-
- def testShuffleNoReshuffleEachIteration(self):
- iterator = (dataset_ops.Dataset.range(10)
- .shuffle(10, reshuffle_each_iteration=False)
- .batch(10)
- .repeat(3)
- .make_one_shot_iterator())
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- initial_permutation = self.evaluate(next_element)
- self.assertAllEqual(initial_permutation, self.evaluate(next_element))
- self.assertAllEqual(initial_permutation, self.evaluate(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testShuffleReshuffleEachIteration(self):
- iterator = (dataset_ops.Dataset.range(10)
- .shuffle(10, seed=3, reshuffle_each_iteration=True)
- .batch(10)
- .repeat(3)
- .make_one_shot_iterator())
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- initial_permutation = list(sess.run(next_element))
- for _ in range(2):
- next_permutation = list(sess.run(next_element))
- self.assertNotEqual(initial_permutation, next_permutation)
- self.assertAllEqual(
- sorted(initial_permutation), sorted(next_permutation))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- @parameterized.named_parameters(
- ("ReshuffleGraphLevelSeed", True, 38, None),
- ("ReshuffleOpLevelSeed", True, None, 42),
- ("ReshuffleGraphAndOpLevelSeed", True, 38, 42),
- ("NoReshuffleGraphLevelSeed", False, 38, None),
- ("NoReshuffleOpLevelSeed", False, None, 42),
- ("NoReshuffleGraphAndOpLevelSeed", False, 38, 42),
- )
- def testShuffleSeed(self, reshuffle, graph_level_seed, op_level_seed):
- results = []
- for _ in range(2):
- with ops.Graph().as_default() as g:
- random_seed.set_random_seed(graph_level_seed)
- dataset = dataset_ops.Dataset.range(10).shuffle(
- 10, seed=op_level_seed, reshuffle_each_iteration=reshuffle).repeat(
- 3)
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- run_results = []
- with self.session(graph=g) as sess:
- for _ in range(30):
- run_results.append(sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- results.append(run_results)
-
- self.assertAllEqual(results[0], results[1])
-
- @parameterized.named_parameters(
- ("ReshuffleOneShot", True, False),
- ("ReshuffleInitializable", True, True),
- ("NoReshuffleOneShot", False, False),
- ("NoReshuffleInitializable", False, True),
- )
- def testMultipleIterators(self, reshuffle, initializable):
- with ops.Graph().as_default() as g:
- dataset = dataset_ops.Dataset.range(100).shuffle(
- 10, reshuffle_each_iteration=reshuffle).repeat(3)
-
- if initializable:
- iterators = [dataset.make_initializable_iterator() for _ in range(2)]
- else:
- iterators = [dataset.make_one_shot_iterator() for _ in range(2)]
-
- results = []
- with self.session(graph=g) as sess:
- for iterator in iterators:
- if initializable:
- self.evaluate(iterator.initializer)
- next_element = iterator.get_next()
- run_results = []
- for _ in range(300):
- run_results.append(sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- results.append(run_results)
-
- self.assertNotEqual(results[0], results[1])
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/shuffle_test.py b/tensorflow/python/data/kernel_tests/shuffle_test.py
new file mode 100644
index 0000000..49460a1
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/shuffle_test.py
@@ -0,0 +1,248 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.shuffle()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ def testShuffleDataset(self):
+ components = (
+ np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]),
+ np.array([9.0, 10.0, 11.0, 12.0])
+ )
+
+ def dataset_fn(count=5, buffer_size=None, seed=0):
+ repeat_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components).repeat(count))
+ if buffer_size:
+ shuffle_dataset = repeat_dataset.shuffle(buffer_size, seed)
+
+ self.assertEqual(
+ tuple([c.shape[1:] for c in components]),
+ shuffle_dataset.output_shapes)
+ return shuffle_dataset
+ else:
+ return repeat_dataset
+
+ # First run without shuffling to collect the "ground truth".
+ get_next = self.getNext(dataset_fn())
+ unshuffled_elements = []
+ for _ in range(20):
+ unshuffled_elements.append(self.evaluate(get_next()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ # Assert that the shuffled dataset has the same elements as the
+ # "ground truth".
+ get_next = self.getNext(dataset_fn(buffer_size=100, seed=37))
+ shuffled_elements = []
+ for _ in range(20):
+ shuffled_elements.append(self.evaluate(get_next()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+ self.assertAllEqual(sorted(unshuffled_elements), sorted(shuffled_elements))
+
+ # Assert that shuffling twice with the same seeds gives the same sequence.
+ get_next = self.getNext(dataset_fn(buffer_size=100, seed=37))
+ reshuffled_elements_same_seed = []
+ for _ in range(20):
+ reshuffled_elements_same_seed.append(self.evaluate(get_next()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+ self.assertEqual(shuffled_elements, reshuffled_elements_same_seed)
+
+ # Assert that shuffling twice with a different seed gives a different
+ # permutation of the same elements.
+ get_next = self.getNext(dataset_fn(buffer_size=100, seed=137))
+ reshuffled_elements_different_seed = []
+ for _ in range(20):
+ reshuffled_elements_different_seed.append(self.evaluate(get_next()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+ self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed)
+ self.assertAllEqual(
+ sorted(shuffled_elements), sorted(reshuffled_elements_different_seed))
+
+ # Assert that the shuffled dataset has the same elements as the
+ # "ground truth" when the buffer size is smaller than the input
+ # dataset.
+ get_next = self.getNext(dataset_fn(buffer_size=2, seed=37))
+ reshuffled_elements_small_buffer = []
+ for _ in range(20):
+ reshuffled_elements_small_buffer.append(self.evaluate(get_next()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+ self.assertAllEqual(
+ sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer))
+
+ # Test the case of shuffling an empty dataset.
+ get_next = self.getNext(dataset_fn(count=0, buffer_size=100, seed=37))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ def testSkipEagerSeedZero(self):
+ """Test for same behavior when the seed is a Python or Tensor zero."""
+ iterator = (
+ dataset_ops.Dataset.range(10).shuffle(10, seed=0)
+ .make_one_shot_iterator())
+ get_next = iterator.get_next()
+
+ elems = []
+ with self.cached_session() as sess:
+ for _ in range(10):
+ elems.append(sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+ iterator = (
+ dataset_ops.Dataset.range(10).shuffle(10, seed=seed_placeholder)
+ .make_initializable_iterator())
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={seed_placeholder: 0})
+ for elem in elems:
+ self.assertEqual(elem, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testDefaultArguments(self):
+ components = [0, 1, 2, 3, 4]
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).shuffle(
+ 5).repeat()
+ get_next = self.getNext(dataset)
+ counts = collections.defaultdict(lambda: 0)
+ for _ in range(10):
+ for _ in range(5):
+ counts[self.evaluate(get_next())] += 1
+
+ for i in range(5):
+ self.assertEqual(10, counts[i])
+
+ def testShuffleNoReshuffleEachIteration(self):
+ dataset = dataset_ops.Dataset.range(10).shuffle(
+ 10, reshuffle_each_iteration=False).batch(10).repeat(3)
+ next_element = self.getNext(dataset)
+
+ initial_permutation = self.evaluate(next_element())
+ self.assertAllEqual(initial_permutation, self.evaluate(next_element()))
+ self.assertAllEqual(initial_permutation, self.evaluate(next_element()))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(next_element())
+
+ def testShuffleReshuffleEachIteration(self):
+ dataset = dataset_ops.Dataset.range(10).shuffle(
+ 10, seed=3, reshuffle_each_iteration=True).batch(10).repeat(3)
+ next_element = self.getNext(dataset)
+
+ initial_permutation = list(self.evaluate(next_element()))
+ for _ in range(2):
+ next_permutation = list(self.evaluate(next_element()))
+ self.assertNotEqual(initial_permutation, next_permutation)
+ self.assertAllEqual(sorted(initial_permutation), sorted(next_permutation))
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(next_element())
+
+ @parameterized.named_parameters(
+ ("ReshuffleGraphLevelSeed", True, 38, None),
+ ("ReshuffleOpLevelSeed", True, None, 42),
+ ("ReshuffleGraphAndOpLevelSeed", True, 38, 42),
+ ("NoReshuffleGraphLevelSeed", False, 38, None),
+ ("NoReshuffleOpLevelSeed", False, None, 42),
+ ("NoReshuffleGraphAndOpLevelSeed", False, 38, 42),
+ )
+ def testSkipEagerShuffleSeed(self, reshuffle, graph_level_seed,
+ op_level_seed):
+ results = []
+ for _ in range(2):
+ with ops.Graph().as_default() as g:
+ random_seed.set_random_seed(graph_level_seed)
+ dataset = dataset_ops.Dataset.range(10).shuffle(
+ 10, seed=op_level_seed, reshuffle_each_iteration=reshuffle).repeat(
+ 3)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ run_results = []
+ with self.session(graph=g) as sess:
+ for _ in range(30):
+ run_results.append(sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ results.append(run_results)
+
+ self.assertAllEqual(results[0], results[1])
+
+ # TODO(b/117581999): fails for eager mode with result[0] equal to result[1],
+ # debug.
+ @parameterized.named_parameters(
+ ("ReshuffleOneShot", True, False),
+ ("ReshuffleInitializable", True, True),
+ ("NoReshuffleOneShot", False, False),
+ ("NoReshuffleInitializable", False, True),
+ )
+ def testSkipEagerMultipleIterators(self, reshuffle, initializable):
+ with ops.Graph().as_default() as g:
+ dataset = dataset_ops.Dataset.range(100).shuffle(
+ 10, reshuffle_each_iteration=reshuffle).repeat(3)
+
+ if initializable:
+ iterators = [dataset.make_initializable_iterator() for _ in range(2)]
+ else:
+ iterators = [dataset.make_one_shot_iterator() for _ in range(2)]
+
+ results = []
+ with self.session(graph=g) as sess:
+ for iterator in iterators:
+ if initializable:
+ sess.run(iterator.initializer)
+ next_element = iterator.get_next()
+ run_results = []
+ for _ in range(300):
+ run_results.append(sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ results.append(run_results)
+
+ self.assertNotEqual(results[0], results[1])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/skip_test.py b/tensorflow/python/data/kernel_tests/skip_test.py
new file mode 100644
index 0000000..c22be57
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/skip_test.py
@@ -0,0 +1,62 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.skip()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class SkipTest(test_base.DatasetTestBase):
+
+ def testSkipTensorDataset(self):
+ components = (np.arange(10),)
+
+ def do_test(count):
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).skip(count)
+ self.assertEqual([c.shape[1:] for c in components],
+ [shape for shape in dataset.output_shapes])
+ start_range = min(count, 10) if count != -1 else 10
+ self.assertDatasetProduces(
+ dataset,
+ [tuple(components[0][i:i + 1]) for i in range(start_range, 10)])
+
+ # Skip fewer than input size, we should skip
+ # the first 4 elements and then read the rest.
+ do_test(4)
+
+ # Skip more than input size: get nothing.
+ do_test(25)
+
+ # Skip exactly input size.
+ do_test(10)
+
+ # Set -1 for 'count': skip the entire dataset.
+ do_test(-1)
+
+ # Skip nothing
+ do_test(0)
+
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/take_test.py b/tensorflow/python/data/kernel_tests/take_test.py
new file mode 100644
index 0000000..03a7ece
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/take_test.py
@@ -0,0 +1,55 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.take()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class TakeTest(test_base.DatasetTestBase):
+
+ def testTakeTensorDataset(self):
+ components = (np.arange(10),)
+
+ def do_test(count):
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).take(count)
+ self.assertEqual([c.shape[1:] for c in components],
+ [shape for shape in dataset.output_shapes])
+ num_output = min(count, 10) if count != -1 else 10
+ self.assertDatasetProduces(
+ dataset, [tuple(components[0][i:i + 1]) for i in range(num_output)])
+
+ # Take fewer than input size
+ do_test(4)
+
+ # Take more than input size
+ do_test(25)
+
+ # Take all of input
+ do_test(-1)
+
+ # Take nothing
+ do_test(0)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
index ca853da..03fc0da 100644
--- a/tensorflow/python/data/kernel_tests/test_base.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -45,8 +45,8 @@
```python
# In both graph and eager modes
dataset = ...
- nxt = self.getNext(dataset)
- result = self.evaluate(nxt())
+ get_next = self.getNext(dataset)
+ result = self.evaluate(get_next())
```
Args:
@@ -66,21 +66,32 @@
self.evaluate(iterator.initializer)
else:
iterator = dataset.make_one_shot_iterator()
- return iterator.get_next
+ get_next = iterator.get_next()
+ return lambda: get_next
- def _compareOutputToExpected(self, result_values, expected_values):
+ def _compareOutputToExpected(self, result_values, expected_values,
+ assert_items_equal):
+ if assert_items_equal:
+ # TODO(shivaniagrawal): add support for nested elements containing sparse
+ # tensors when needed.
+ self.assertItemsEqual(result_values, expected_values)
+ return
for i in range(len(result_values)):
- if sparse_tensor.is_sparse(result_values[i]):
- self.assertSparseValuesEqual(result_values[i], expected_values[i])
- else:
- self.assertAllEqual(result_values[i], expected_values[i])
+ nest.assert_same_structure(result_values[i], expected_values[i])
+ for result_value, expected_value in zip(
+ nest.flatten(result_values[i]), nest.flatten(expected_values[i])):
+ if sparse_tensor.is_sparse(result_value):
+ self.assertSparseValuesEqual(result_value, expected_value)
+ else:
+ self.assertAllEqual(result_value, expected_value)
def assertDatasetProduces(self,
dataset,
expected_output=None,
expected_error=None,
requires_initialization=False,
- num_test_iterations=2):
+ num_test_iterations=1,
+ assert_items_equal=False):
"""Asserts that a dataset produces the expected output / error.
Args:
@@ -98,6 +109,8 @@
dataset (e.g. when it contains stateful nodes). Defaults to False.
num_test_iterations: Number of times `dataset` will be iterated. Defaults
to 2.
+ assert_items_equal: Tests expected_output has (only) the same elements
+ regardless of order.
"""
self.assertTrue(
expected_error is not None or expected_output is not None,
@@ -120,7 +133,7 @@
result = []
for _ in range(len(expected_output)):
result.append(self.evaluate(get_next()))
- self._compareOutputToExpected(result, expected_output)
+ self._compareOutputToExpected(result, expected_output, assert_items_equal)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
with self.assertRaises(errors.OutOfRangeError):
diff --git a/tensorflow/python/data/kernel_tests/text_line_dataset_test.py b/tensorflow/python/data/kernel_tests/text_line_dataset_test.py
new file mode 100644
index 0000000..4db09a9
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/text_line_dataset_test.py
@@ -0,0 +1,165 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.TextLineDataset`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import zlib
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import readers
+from tensorflow.python.eager import context
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+try:
+ import psutil # pylint: disable=g-import-not-at-top
+ psutil_import_succeeded = True
+except ImportError:
+ psutil_import_succeeded = False
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class TextLineDatasetTest(test_base.DatasetTestBase):
+
+ def _lineText(self, f, l):
+ return compat.as_bytes("%d: %d" % (f, l))
+
+ def _createFiles(self,
+ num_files,
+ num_lines,
+ crlf=False,
+ compression_type=None):
+ filenames = []
+ for i in range(num_files):
+ fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
+ filenames.append(fn)
+ contents = []
+ for j in range(num_lines):
+ contents.append(self._lineText(i, j))
+ # Always include a newline after the record unless it is
+ # at the end of the file, in which case we include it
+ if j + 1 != num_lines or i == 0:
+ contents.append(b"\r\n" if crlf else b"\n")
+ contents = b"".join(contents)
+
+ if not compression_type:
+ with open(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "GZIP":
+ with gzip.GzipFile(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "ZLIB":
+ contents = zlib.compress(contents)
+ with open(fn, "wb") as f:
+ f.write(contents)
+ else:
+ raise ValueError("Unsupported compression_type", compression_type)
+
+ return filenames
+
+ def _testTextLineDataset(self, compression_type=None):
+ test_filenames = self._createFiles(
+ 2, 5, crlf=True, compression_type=compression_type)
+
+ def dataset_fn(filenames, num_epochs, batch_size=None):
+ repeat_dataset = readers.TextLineDataset(
+ filenames, compression_type=compression_type).repeat(num_epochs)
+ if batch_size:
+ return repeat_dataset.batch(batch_size)
+ return repeat_dataset
+
+ # Basic test: read from file 0.
+ expected_output = [self._lineText(0, i) for i in range(5)]
+ self.assertDatasetProduces(
+ dataset_fn([test_filenames[0]], 1), expected_output=expected_output)
+
+ # Basic test: read from file 1.
+ self.assertDatasetProduces(
+ dataset_fn([test_filenames[1]], 1),
+ expected_output=[self._lineText(1, i) for i in range(5)])
+
+ # Basic test: read from both files.
+ expected_output = [self._lineText(0, i) for i in range(5)]
+ expected_output.extend([self._lineText(1, i) for i in range(5)])
+ self.assertDatasetProduces(
+ dataset_fn(test_filenames, 1), expected_output=expected_output)
+
+ # Test repeated iteration through both files.
+ expected_output = [self._lineText(0, i) for i in range(5)]
+ expected_output.extend([self._lineText(1, i) for i in range(5)])
+ self.assertDatasetProduces(
+ dataset_fn(test_filenames, 10), expected_output=expected_output * 10)
+
+ # Test batched and repeated iteration through both files.
+ self.assertDatasetProduces(
+ dataset_fn(test_filenames, 10, 5),
+ expected_output=[[self._lineText(0, i) for i in range(5)],
+ [self._lineText(1, i) for i in range(5)]] * 10)
+
+ def testTextLineDatasetNoCompression(self):
+ self._testTextLineDataset()
+
+ def testTextLineDatasetGzipCompression(self):
+ self._testTextLineDataset(compression_type="GZIP")
+
+ def testTextLineDatasetZlibCompression(self):
+ self._testTextLineDataset(compression_type="ZLIB")
+
+ def testTextLineDatasetBuffering(self):
+ test_filenames = self._createFiles(2, 5, crlf=True)
+
+ repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10)
+ expected_output = []
+ for j in range(2):
+ expected_output.extend([self._lineText(j, i) for i in range(5)])
+ self.assertDatasetProduces(repeat_dataset, expected_output=expected_output)
+
+ def testIteratorResourceCleanup(self):
+ filename = os.path.join(self.get_temp_dir(), "text.txt")
+ with open(filename, "wt") as f:
+ for i in range(3):
+ f.write("%d\n" % (i,))
+ with context.eager_mode():
+ first_iterator = iter(readers.TextLineDataset(filename))
+ self.assertEqual(b"0", next(first_iterator).numpy())
+ second_iterator = iter(readers.TextLineDataset(filename))
+ self.assertEqual(b"0", next(second_iterator).numpy())
+ # Eager kernel caching is based on op attributes, which includes the
+ # Dataset's output shape. Create a different kernel to test that they
+ # don't create resources with the same names.
+ different_kernel_iterator = iter(
+ readers.TextLineDataset(filename).repeat().batch(16))
+ self.assertEqual([16], next(different_kernel_iterator).shape)
+ # Remove our references to the Python Iterator objects, which (assuming no
+ # reference cycles) is enough to trigger DestroyResourceOp and close the
+ # partially-read files.
+ del first_iterator
+ del second_iterator
+ del different_kernel_iterator
+ if not psutil_import_succeeded:
+ self.skipTest(
+ "psutil is required to check that we've closed our files.")
+ open_files = psutil.Process().open_files()
+ self.assertNotIn(filename, [open_file.path for open_file in open_files])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py b/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py
new file mode 100644
index 0000000..13a70aa
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/tf_record_dataset_test.py
@@ -0,0 +1,170 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.TFRecordDataset`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import zlib
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.framework import test_util
+from tensorflow.python.lib.io import python_io
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class TFRecordDatasetTest(test_base.DatasetTestBase):
+
+ def setUp(self):
+ super(TFRecordDatasetTest, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+
+ self.test_filenames = self._createFiles()
+
+ def dataset_fn(self,
+ filenames,
+ compression_type="",
+ num_epochs=1,
+ batch_size=None):
+
+ repeat_dataset = readers.TFRecordDataset(
+ filenames, compression_type).repeat(num_epochs)
+ if batch_size:
+ return repeat_dataset.batch(batch_size)
+ return repeat_dataset
+
+ def _record(self, f, r):
+ return compat.as_bytes("Record %d of file %d" % (r, f))
+
+ def _createFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
+ filenames.append(fn)
+ writer = python_io.TFRecordWriter(fn)
+ for j in range(self._num_records):
+ writer.write(self._record(i, j))
+ writer.close()
+ return filenames
+
+ def testReadOneEpoch(self):
+ # Basic test: read from file 0.
+ dataset = self.dataset_fn(self.test_filenames[0])
+ self.assertDatasetProduces(
+ dataset,
+ expected_output=[self._record(0, i) for i in range(self._num_records)])
+
+ # Basic test: read from file 1.
+ dataset = self.dataset_fn(self.test_filenames[1])
+ self.assertDatasetProduces(
+ dataset,
+ expected_output=[self._record(1, i) for i in range(self._num_records)])
+
+ # Basic test: read from both files.
+ dataset = self.dataset_fn(self.test_filenames)
+ expected_output = []
+ for j in range(self._num_files):
+ expected_output.extend(
+ [self._record(j, i) for i in range(self._num_records)])
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ def testReadTenEpochs(self):
+ dataset = self.dataset_fn(self.test_filenames, num_epochs=10)
+ expected_output = []
+ for j in range(self._num_files):
+ expected_output.extend(
+ [self._record(j, i) for i in range(self._num_records)])
+ self.assertDatasetProduces(dataset, expected_output=expected_output * 10)
+
+ def testReadTenEpochsOfBatches(self):
+ dataset = self.dataset_fn(
+ self.test_filenames, num_epochs=10, batch_size=self._num_records)
+ expected_output = []
+ for j in range(self._num_files):
+ expected_output.append(
+ [self._record(j, i) for i in range(self._num_records)])
+ self.assertDatasetProduces(dataset, expected_output=expected_output * 10)
+
+ def testReadZlibFiles(self):
+ zlib_files = []
+ for i, fn in enumerate(self.test_filenames):
+ with open(fn, "rb") as f:
+ cdata = zlib.compress(f.read())
+
+ zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
+ with open(zfn, "wb") as f:
+ f.write(cdata)
+ zlib_files.append(zfn)
+ expected_output = []
+ for j in range(self._num_files):
+ expected_output.extend(
+ [self._record(j, i) for i in range(self._num_records)])
+ dataset = self.dataset_fn(zlib_files, compression_type="ZLIB")
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ def testReadGzipFiles(self):
+ gzip_files = []
+ for i, fn in enumerate(self.test_filenames):
+ with open(fn, "rb") as f:
+ gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
+ with gzip.GzipFile(gzfn, "wb") as gzf:
+ gzf.write(f.read())
+ gzip_files.append(gzfn)
+ expected_output = []
+ for j in range(self._num_files):
+ expected_output.extend(
+ [self._record(j, i) for i in range(self._num_records)])
+ dataset = self.dataset_fn(gzip_files, compression_type="GZIP")
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ def testReadWithBuffer(self):
+ one_mebibyte = 2**20
+ dataset = readers.TFRecordDataset(
+ self.test_filenames, buffer_size=one_mebibyte)
+ expected_output = []
+ for j in range(self._num_files):
+ expected_output.extend(
+ [self._record(j, i) for i in range(self._num_records)])
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ def testReadFromDatasetOfFiles(self):
+ files = dataset_ops.Dataset.from_tensor_slices(self.test_filenames)
+ expected_output = []
+ for j in range(self._num_files):
+ expected_output.extend(
+ [self._record(j, i) for i in range(self._num_records)])
+ dataset = readers.TFRecordDataset(files)
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ def testReadTenEpochsFromDatasetOfFilesInParallel(self):
+ files = dataset_ops.Dataset.from_tensor_slices(
+ self.test_filenames).repeat(10)
+ expected_output = []
+ for j in range(self._num_files):
+ expected_output.extend(
+ [self._record(j, i) for i in range(self._num_records)])
+ dataset = readers.TFRecordDataset(files, num_parallel_reads=4)
+ self.assertDatasetProduces(
+ dataset, expected_output=expected_output * 10, assert_items_equal=True)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
deleted file mode 100644
index 35adcdd..0000000
--- a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
+++ /dev/null
@@ -1,291 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @parameterized.named_parameters(
- ("1", 20, 14, 7, 1),
- ("2", 20, 17, 9, 1),
- ("3", 20, 14, 14, 1),
- ("4", 20, 10, 14, 1),
- ("5", 20, 14, 19, 1),
- ("6", 20, 4, 1, 2),
- ("7", 20, 2, 1, 6),
- ("8", 20, 4, 7, 2),
- ("9", 20, 2, 7, 6),
- ("10", 1, 10, 4, 1),
- ("11", 0, 10, 4, 1),
- ("12", 20, 14, 7, 1, False),
- ("13", 20, 17, 9, 1, False),
- ("14", 20, 14, 14, 1, False),
- ("15", 20, 10, 14, 1, False),
- ("16", 20, 14, 19, 1, False),
- ("17", 20, 4, 1, 2, False),
- ("18", 20, 2, 1, 6, False),
- ("19", 20, 4, 7, 2, False),
- ("20", 20, 2, 7, 6, False),
- ("21", 1, 10, 4, 1, False),
- ("22", 0, 10, 4, 1, False),
- )
- def testWindowDataset(self, count, size, shift, stride, drop_remainder=True):
- """Tests a dataset that slides a window its input elements."""
- components = (np.arange(7),
- np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
- np.array(37.0) * np.arange(7))
-
- count_t = array_ops.placeholder(dtypes.int64, shape=[])
- size_t = array_ops.placeholder(dtypes.int64, shape=[])
- shift_t = array_ops.placeholder(dtypes.int64, shape=[])
- stride_t = array_ops.placeholder(dtypes.int64, shape=[])
- drop_remainder_t = array_ops.placeholder(dtypes.bool, shape=[])
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- def _flat_map_fn(x, y, z):
- return dataset_ops.Dataset.zip((x.batch(batch_size=size_t),
- y.batch(batch_size=size_t),
- z.batch(batch_size=size_t)))
-
- iterator = dataset_ops.Dataset.from_tensor_slices(components).map(
- _map_fn).repeat(count).window(
- size=size_t,
- shift=shift_t,
- stride=stride_t,
- drop_remainder=drop_remainder_t).flat_map(
- _flat_map_fn).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([[None] + list(c.shape[1:]) for c in components],
- [t.shape.as_list() for t in get_next])
-
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- count_t: count,
- size_t: size,
- shift_t: shift,
- stride_t: stride,
- drop_remainder_t: drop_remainder
- })
- num_full_batches = max(
- 0, (count * 7 - ((size - 1) * stride + 1)) // shift + 1)
- for i in range(num_full_batches):
- result = self.evaluate(get_next)
- for component, result_component in zip(components, result):
- for j in range(size):
- self.assertAllEqual(component[(i * shift + j * stride) % 7]**2,
- result_component[j])
- if not drop_remainder:
- num_partial_batches = (count * 7) // shift + (
- (count * 7) % shift > 0) - num_full_batches
- for i in range(num_partial_batches):
- result = self.evaluate(get_next)
- for component, result_component in zip(components, result):
- remaining = (count * 7) - ((num_full_batches + i) * shift)
- num_elements = remaining // stride + ((remaining % stride) > 0)
- for j in range(num_elements):
- self.assertAllEqual(
- component[((num_full_batches + i) * shift + j * stride) % 7]
- **2, result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- @parameterized.named_parameters(
- ("1", 14, 0, 3, 1),
- ("2", 14, 3, 0, 1),
- ("3", 14, 3, 3, 0),
- )
- def testWindowDatasetInvalid(self, count, size, shift, stride):
- count_t = array_ops.placeholder(dtypes.int64, shape=[])
- size_t = array_ops.placeholder(dtypes.int64, shape=[])
- shift_t = array_ops.placeholder(dtypes.int64, shape=[])
- stride_t = array_ops.placeholder(dtypes.int64, shape=[])
-
- iterator = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(
- count_t).window(
- size=size_t, shift=shift_t,
- stride=stride_t).flat_map(lambda x: x.batch(batch_size=size_t)
- ).make_initializable_iterator()
- init_op = iterator.initializer
-
- with self.cached_session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(
- init_op,
- feed_dict={
- count_t: count,
- size_t: size,
- shift_t: shift,
- stride_t: stride
- })
-
- def testWindowSparse(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
- size=5, shift=3, drop_remainder=True).flat_map(
- lambda x: x.batch(batch_size=5)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- num_batches = (10 - 5) // 3 + 1
- for i in range(num_batches):
- actual = self.evaluate(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
- values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
- dense_shape=[5, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testWindowSparseWithDifferentDenseShapes(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=array_ops.expand_dims(
- math_ops.range(i, dtype=dtypes.int64), 1),
- values=array_ops.fill([math_ops.to_int32(i)], i),
- dense_shape=[i])
-
- iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
- size=5, shift=3, drop_remainder=True).flat_map(
- lambda x: x.batch(batch_size=5)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- num_batches = (10 - 5) // 3 + 1
- for i in range(num_batches):
- actual = self.evaluate(get_next)
- expected_indices = []
- expected_values = []
- for j in range(5):
- for k in range(i * 3 + j):
- expected_indices.append([j, k])
- expected_values.append(i * 3 + j)
- expected = sparse_tensor.SparseTensorValue(
- indices=expected_indices,
- values=expected_values,
- dense_shape=[5, i * 3 + 5 - 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testNestedWindowSparse(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- iterator = dataset_ops.Dataset.range(10).map(_sparse).window(
- size=4, shift=2,
- drop_remainder=True).flat_map(lambda x: x.batch(batch_size=4)).window(
- size=3, shift=1, drop_remainder=True).flat_map(
- lambda x: x.batch(batch_size=3)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(init_op)
- # Slide: 1st batch.
- actual = self.evaluate(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
- [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
- [2, 2, 0], [2, 3, 0]],
- values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
- dense_shape=[3, 4, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- # Slide: 2nd batch.
- actual = self.evaluate(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
- [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
- [2, 2, 0], [2, 3, 0]],
- values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
- dense_shape=[3, 4, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testWindowShapeError(self):
-
- def generator():
- yield [1.0, 2.0, 3.0]
- yield [4.0, 5.0, 6.0]
- yield [7.0, 8.0, 9.0, 10.0]
-
- iterator = dataset_ops.Dataset.from_generator(
- generator, dtypes.float32, output_shapes=[None]).window(
- size=3, shift=1).flat_map(
- lambda x: x.batch(batch_size=3)).make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- self.evaluate(iterator.initializer)
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- r"Cannot batch tensors with different shapes in component 0. "
- r"First element had shape \[3\] and element 2 had shape \[4\]."):
- sess.run(next_element)
-
- def testWindowIgnoreErrors(self):
- input_values = np.float32([1., np.nan, 2., np.nan, 3.])
- dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
- lambda x: array_ops.check_numerics(x, "message")).window(
- size=2, shift=2, stride=2,
- drop_remainder=True).flat_map(lambda x: x.batch(batch_size=2))
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- self.assertAllEqual(np.float32([1., 2.]), self.evaluate(get_next))
- self.assertAllEqual(np.float32([2., 3.]), self.evaluate(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/window_test.py b/tensorflow/python/data/kernel_tests/window_test.py
new file mode 100644
index 0000000..d083142
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/window_test.py
@@ -0,0 +1,231 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.window()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
+ ("12", 20, 14, 7, 1, False),
+ ("13", 20, 17, 9, 1, False),
+ ("14", 20, 14, 14, 1, False),
+ ("15", 20, 10, 14, 1, False),
+ ("16", 20, 14, 19, 1, False),
+ ("17", 20, 4, 1, 2, False),
+ ("18", 20, 2, 1, 6, False),
+ ("19", 20, 4, 7, 2, False),
+ ("20", 20, 2, 7, 6, False),
+ ("21", 1, 10, 4, 1, False),
+ ("22", 0, 10, 4, 1, False),
+ )
+ def testWindowDataset(self, count, size, shift, stride, drop_remainder=True):
+ """Tests a dataset that slides a window its input elements."""
+ components = (np.arange(7),
+ np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+ np.array(37.0) * np.arange(7))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ def _flat_map_fn(x, y, z):
+ return dataset_ops.Dataset.zip((x.batch(batch_size=size),
+ y.batch(batch_size=size),
+ z.batch(batch_size=size)))
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
+ _map_fn).repeat(count).window(
+ size=size,
+ shift=shift,
+ stride=stride,
+ drop_remainder=drop_remainder).flat_map(_flat_map_fn)
+ get_next = self.getNext(dataset)
+
+ self.assertEqual(
+ [[None] + list(c.shape[1:]) for c in components],
+ [ts.as_list() for ts in nest.flatten(dataset.output_shapes)])
+
+ num_full_batches = max(0,
+ (count * 7 - ((size - 1) * stride + 1)) // shift + 1)
+ for i in range(num_full_batches):
+ result = self.evaluate(get_next())
+ for component, result_component in zip(components, result):
+ for j in range(size):
+ self.assertAllEqual(component[(i * shift + j * stride) % 7]**2,
+ result_component[j])
+ if not drop_remainder:
+ num_partial_batches = (count * 7) // shift + (
+ (count * 7) % shift > 0) - num_full_batches
+ for i in range(num_partial_batches):
+ result = self.evaluate(get_next())
+ for component, result_component in zip(components, result):
+ remaining = (count * 7) - ((num_full_batches + i) * shift)
+ num_elements = remaining // stride + ((remaining % stride) > 0)
+ for j in range(num_elements):
+ self.assertAllEqual(
+ component[((num_full_batches + i) * shift + j * stride) % 7]**2,
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ @parameterized.named_parameters(
+ ("1", 14, 0, 3, 1),
+ ("2", 14, 3, 0, 1),
+ ("3", 14, 3, 3, 0),
+ )
+ def testWindowDatasetInvalid(self, count, size, shift, stride):
+ dataset = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(
+ count).window(
+ size=size, shift=shift,
+ stride=stride).flat_map(lambda x: x.batch(batch_size=size))
+ self.assertDatasetProduces(
+ dataset, expected_error=(errors.InvalidArgumentError, ""))
+
+ def testWindowSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ dataset = dataset_ops.Dataset.range(10).map(_sparse).window(
+ size=5, shift=3,
+ drop_remainder=True).flat_map(lambda x: x.batch(batch_size=5))
+
+ num_batches = (10 - 5) // 3 + 1
+ expected_output = [
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
+ values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
+ dense_shape=[5, 1]) for i in range(num_batches)
+ ]
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ def testWindowSparseWithDifferentDenseShapes(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=array_ops.expand_dims(
+ math_ops.range(i, dtype=dtypes.int64), 1),
+ values=array_ops.fill([math_ops.to_int32(i)], i),
+ dense_shape=[i])
+
+ dataset = dataset_ops.Dataset.range(10).map(_sparse).window(
+ size=5, shift=3,
+ drop_remainder=True).flat_map(lambda x: x.batch(batch_size=5))
+
+ expected_output = []
+ num_batches = (10 - 5) // 3 + 1
+ for i in range(num_batches):
+ expected_indices = []
+ expected_values = []
+ for j in range(5):
+ for k in range(i * 3 + j):
+ expected_indices.append([j, k])
+ expected_values.append(i * 3 + j)
+ expected_output.append(
+ sparse_tensor.SparseTensorValue(
+ indices=expected_indices,
+ values=expected_values,
+ dense_shape=[5, i * 3 + 5 - 1]))
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ def testNestedWindowSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ dataset = dataset_ops.Dataset.range(10).map(_sparse).window(
+ size=4, shift=2,
+ drop_remainder=True).flat_map(lambda x: x.batch(batch_size=4)).window(
+ size=3, shift=1,
+ drop_remainder=True).flat_map(lambda x: x.batch(batch_size=3))
+
+ expected_output = [
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
+ [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
+ [2, 2, 0], [2, 3, 0]],
+ values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
+ dense_shape=[3, 4, 1]),
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
+ [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
+ [2, 2, 0], [2, 3, 0]],
+ values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
+ dense_shape=[3, 4, 1])
+ ]
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ def testWindowShapeError(self):
+
+ def generator():
+ yield [1.0, 2.0, 3.0]
+ yield [4.0, 5.0, 6.0]
+ yield [7.0, 8.0, 9.0, 10.0]
+
+ dataset = dataset_ops.Dataset.from_generator(
+ generator, dtypes.float32, output_shapes=[None]).window(
+ size=3, shift=1).flat_map(lambda x: x.batch(batch_size=3))
+ self.assertDatasetProduces(
+ dataset,
+ expected_error=(
+ errors.InvalidArgumentError,
+ r"Cannot batch tensors with different shapes in component 0. "
+ r"First element had shape \[3\] and element 2 had shape \[4\]."))
+
+ def testWindowIgnoreErrors(self):
+ input_values = np.float32([1., np.nan, 2., np.nan, 3.])
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
+ lambda x: array_ops.check_numerics(x, "message")).window(
+ size=2, shift=2, stride=2,
+ drop_remainder=True).flat_map(lambda x: x.batch(batch_size=2))
+ self.assertDatasetProduces(
+ dataset, expected_output=[np.float32([1., 2.]),
+ np.float32([2., 3.])])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
deleted file mode 100644
index b60ec4e..0000000
--- a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class ZipDatasetTest(test_base.DatasetTestBase):
-
- def testZipDataset(self):
- component_placeholders = [
- array_ops.placeholder(dtypes.int64),
- array_ops.placeholder(dtypes.int64),
- array_ops.placeholder(dtypes.float64)
- ]
-
- datasets = tuple([
- dataset_ops.Dataset.from_tensor_slices(component_placeholder)
- for component_placeholder in component_placeholders
- ])
- zipped = dataset_ops.Dataset.zip(datasets)
-
- iterator = zipped.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- equal_length_components = [
- np.tile(np.array([[1], [2], [3], [4]]), 20),
- np.tile(np.array([[12], [13], [14], [15]]), 22),
- np.array([37.0, 38.0, 39.0, 40.0])
- ]
- sess.run(init_op, feed_dict={ph: value for ph, value in zip(
- component_placeholders, equal_length_components)})
- for i in range(4):
- results = self.evaluate(get_next)
- for component, result_component in zip(
- equal_length_components, results):
- self.assertAllEqual(component[i], result_component)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- variable_length_components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]]
- sess.run(init_op, feed_dict={ph: value for ph, value in zip(
- component_placeholders, variable_length_components)})
- for i in range(2):
- results = self.evaluate(get_next)
- for component, result_component in zip(
- variable_length_components, results):
- self.assertAllEqual(component[i], result_component)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testNestedZipDataset(self):
- component_placeholders = [
- array_ops.placeholder(dtypes.int64, shape=[4, 20]),
- array_ops.placeholder(dtypes.int64, shape=[4, 22]),
- array_ops.placeholder(dtypes.float64, shape=[4])
- ]
-
- datasets = [
- dataset_ops.Dataset.from_tensor_slices(component_placeholder)
- for component_placeholder in component_placeholders
- ]
- zipped = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
-
- iterator = zipped.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([20], get_next[0].shape)
- self.assertEqual([22], get_next[1][0].shape)
- self.assertEqual([], get_next[1][1].shape)
-
- with self.cached_session() as sess:
- equal_length_components = [
- np.tile(np.array([[1], [2], [3], [4]]), 20),
- np.tile(np.array([[12], [13], [14], [15]]), 22),
- np.array([37.0, 38.0, 39.0, 40.0])
- ]
- sess.run(init_op, feed_dict={ph: value for ph, value in zip(
- component_placeholders, equal_length_components)})
- for i in range(4):
- result1, (result2, result3) = self.evaluate(get_next)
- self.assertAllEqual(equal_length_components[0][i], result1)
- self.assertAllEqual(equal_length_components[1][i], result2)
- self.assertAllEqual(equal_length_components[2][i], result3)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/kernel_tests/zip_test.py b/tensorflow/python/data/kernel_tests/zip_test.py
new file mode 100644
index 0000000..477c9fa7
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/zip_test.py
@@ -0,0 +1,101 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.Dataset.zip()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class ZipTest(test_base.DatasetTestBase):
+
+ def testZipDataset(self):
+
+ def dataset_fn(components):
+ datasets = tuple([
+ dataset_ops.Dataset.from_tensor_slices(component)
+ for component in components
+ ])
+ return dataset_ops.Dataset.zip(datasets)
+
+ equal_length_components = [
+ np.tile(np.array([[1], [2], [3], [4]]), 20),
+ np.tile(np.array([[12], [13], [14], [15]]), 22),
+ np.array([37.0, 38.0, 39.0, 40.0])
+ ]
+
+ get_next = self.getNext(dataset_fn(equal_length_components))
+ for i in range(4):
+ results = self.evaluate(get_next())
+ for component, result_component in zip(equal_length_components, results):
+ self.assertAllEqual(component[i], result_component)
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ variable_length_components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]]
+ get_next = self.getNext(dataset_fn(variable_length_components))
+ for i in range(2):
+ results = self.evaluate(get_next())
+ for component, result_component in zip(variable_length_components,
+ results):
+ self.assertAllEqual(component[i], result_component)
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+ def testNestedZipDataset(self):
+
+ equal_length_components = [
+ np.tile(np.array([[1], [2], [3], [4]]), 20),
+ np.tile(np.array([[12], [13], [14], [15]]), 22),
+ np.array([37.0, 38.0, 39.0, 40.0])
+ ]
+ datasets = [
+ dataset_ops.Dataset.from_tensor_slices(component)
+ for component in equal_length_components
+ ]
+ dataset = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
+
+ self.assertEqual(
+ dataset.output_shapes,
+ (tensor_shape.TensorShape([20]),
+ (tensor_shape.TensorShape([22]), tensor_shape.TensorShape([]))))
+
+ get_next = self.getNext(dataset)
+ for i in range(4):
+ result1, (result2, result3) = self.evaluate(get_next())
+ self.assertAllEqual(equal_length_components[0][i], result1)
+ self.assertAllEqual(equal_length_components[1][i], result2)
+ self.assertAllEqual(equal_length_components[2][i], result3)
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(get_next())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 18edc08..dcbb0f1 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -14,6 +14,7 @@
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
"//tensorflow/python:math_ops",
@@ -26,7 +27,9 @@
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//tensorflow/python/data/experimental/ops:stats_options",
+ "//tensorflow/python/data/experimental/ops:threading_options",
"//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:options",
"//tensorflow/python/data/util:random_seed",
"//tensorflow/python/data/util:sparse",
"//tensorflow/python/data/util:structure",
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 4a11619..71175fc 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -27,10 +27,13 @@
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import stats_options
+from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import options as options_lib
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import sparse
+from tensorflow.python.data.util import structure as structure_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -44,6 +47,7 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
@@ -53,6 +57,9 @@
from tensorflow.python.util.tf_export import tf_export
+ops.NotDifferentiable("ReduceDataset")
+
+
@tf_export("data.Dataset", v1=[])
@six.add_metaclass(abc.ABCMeta)
class DatasetV2(object):
@@ -106,6 +113,14 @@
dataset = self
options = self.options()
+ if options.experimental_threading is not None:
+ t_options = options.experimental_threading
+ if t_options.private_threadpool_size is not None:
+ dataset = _PrivateThreadPoolDataset(dataset,
+ t_options.private_threadpool_size)
+ if t_options.max_intra_op_parallelism is not None:
+ dataset = _MaxIntraOpParallelismDataset(
+ dataset, t_options.max_intra_op_parallelism)
static_optimizations = options._static_optimizations() # pylint: disable=protected-access
if static_optimizations:
dataset = _OptimizeDataset(dataset, static_optimizations)
@@ -278,9 +293,9 @@
Note that if `tensors` contains a NumPy array, and eager execution is not
enabled, the values will be embedded in the graph as one or more
`tf.constant` operations. For large datasets (> 1 GB), this can waste
- memory and run into byte limits of graph serialization. If tensors contains
- one or more large NumPy arrays, consider the alternative described in
- [this guide](https://tensorflow.org/guide/datasets#consuming_numpy_arrays).
+ memory and run into byte limits of graph serialization. If `tensors`
+ contains one or more large NumPy arrays, consider the alternative described
+ in [this guide](https://tensorflow.org/guide/datasets#consuming_numpy_arrays).
Args:
tensors: A nested structure of tensors.
@@ -297,9 +312,9 @@
Note that if `tensors` contains a NumPy array, and eager execution is not
enabled, the values will be embedded in the graph as one or more
`tf.constant` operations. For large datasets (> 1 GB), this can waste
- memory and run into byte limits of graph serialization. If tensors contains
- one or more large NumPy arrays, consider the alternative described in
- [this guide](https://tensorflow.org/guide/datasets#consuming_numpy_arrays).
+ memory and run into byte limits of graph serialization. If `tensors`
+ contains one or more large NumPy arrays, consider the alternative described
+ in [this guide](https://tensorflow.org/guide/datasets#consuming_numpy_arrays).
Args:
tensors: A nested structure of tensors, each having the same size in the
@@ -566,7 +581,7 @@
```
Args:
- *args: follow same semantics as python's xrange.
+ *args: follows the same semantics as python's xrange.
len(args) == 1 -> start = 0, stop = args[0], step = 1
len(args) == 2 -> start = args[0], stop = args[1], step = 1
len(args) == 3 -> start = args[0], stop = args[1, stop = args[2]
@@ -852,10 +867,10 @@
Raises:
ValueError: if `num_shards` or `index` are illegal values. Note: error
- checking is done on a best-effort basis, and aren't guaranteed to be
- caught upon dataset creation. (e.g. providing in a placeholder tensor
- bypasses the early checking, and will instead result in an error during
- a session.run call.)
+ checking is done on a best-effort basis, and errors aren't guaranteed
+ to be caught upon dataset creation. (e.g. providing in a placeholder
+ tensor bypasses the early checking, and will instead result in an error
+ during a session.run call.)
"""
num_shards = ops.convert_to_tensor(
num_shards, name="num_shards", dtype=dtypes.int64)
@@ -892,7 +907,7 @@
batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
consecutive elements of this dataset to combine in a single batch.
drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
- whether the last batch should be dropped in the case its has fewer than
+ whether the last batch should be dropped in the case it has fewer than
`batch_size` elements; the default behavior is not to drop the smaller
batch.
@@ -949,7 +964,7 @@
respective components. Defaults are `0` for numeric types and
the empty string for string types.
drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
- whether the last batch should be dropped in the case its has fewer than
+ whether the last batch should be dropped in the case it has fewer than
`batch_size` elements; the default behavior is not to drop the smaller
batch.
@@ -1370,10 +1385,9 @@
def with_options(self, options):
"""Returns a new `tf.data.Dataset` with the given options set.
- The options are "global" in the sense they apply to the entire input
- pipeline in which the `with_options` transformation is used. If options are
- set multiple times, they are merged if possible (see
- `tf.data.Options.merge()` for details).
+ The options are "global" in the sense they apply to the entire dataset.
+ If options are set multiple times, they are merged as long as different
+ options do not use different non-default values.
Args:
options: A `tf.data.Options` that identifies the options the use.
@@ -1382,7 +1396,7 @@
Dataset: A `Dataset` with the given options.
Raises:
- ValueError: if options are set more than once
+ ValueError: when an option is set more than once to a non-default value
"""
return _OptionsDataset(self, options)
@@ -1570,77 +1584,89 @@
@tf_export("data.Options")
-class Options(object):
+class Options(options_lib.OptionsBase):
"""Represents options for tf.data.Dataset.
- An `Options` object can be for instance used to control which static
+ An `Options` object can be, for instance, used to control which static
optimizations to apply or whether to use performance modeling to dynamically
tune the parallelism of operations such as `tf.data.Dataset.map` or
`tf.data.Dataset.interleave`.
"""
- for _name, _ty, _docstring in [
- ("experimental_autotune", bool,
- "Whether to dynamically adjust the values of tunable parameters (e.g. "
- "degrees of parallelism)."),
- ("experimental_deterministic", bool,
- "Whether the outputs need to be produced in deterministic order."),
- ("experimental_filter_fusion", bool,
- "Whether to fuse filter transformations."),
- ("experimental_hoist_random_uniform", bool,
- "Whether to hoist `tf.random_uniform()` ops out of map transformations."
- ),
- ("experimental_stats", stats_options.StatsOptions,
- "Associate the given statistics options with the dataset pipeline."),
- ("experimental_map_and_batch_fusion", bool,
- "Whether to fuse map and batch transformations."),
- ("experimental_map_and_filter_fusion", bool,
- "Whether to fuse map and filter transformations."),
- ("experimental_map_fusion", bool, "Whether to fuse map transformations."),
- ("experimental_map_parallelization", bool,
- "Whether to parallelize stateless map transformations."),
- ("experimental_map_vectorization", bool,
- "Whether to vectorize map transformations."),
- ("experimental_noop_elimination", bool,
- "Whether to eliminate no-op transformations."),
- ("experimental_shuffle_and_repeat_fusion", bool,
- "Whether to fuse shuffle and repeat transformations."),
- ("experimental_numa_aware", bool,
- "Whether to use NUMA-aware operations."),
- ]:
- def _make_getter(name): # pylint: disable=no-self-argument
+ experimental_autotune = options_lib.create_option(
+ name="experimental_autotune",
+ ty=bool,
+ docstring=
+ "Whether to dynamically adjust the values of tunable parameters (e.g. "
+ "degrees of parallelism).")
- def getter(self):
- return getattr(self, "_" + name)
+ experimental_deterministic = options_lib.create_option(
+ name="experimental_deterministic",
+ ty=bool,
+ docstring=
+ "Whether to dynamically adjust the values of tunable parameters (e.g. "
+ "degrees of parallelism).")
- return getter
+ experimental_filter_fusion = options_lib.create_option(
+ name="experimental_filter_fusion",
+ ty=bool,
+ docstring="Whether to fuse filter transformations.")
- def _make_setter(name, ty): # pylint: disable=no-self-argument
+ experimental_hoist_random_uniform = options_lib.create_option(
+ name="experimental_hoist_random_uniform",
+ ty=bool,
+ docstring=
+ "Whether to hoist `tf.random_uniform()` ops out of map transformations.")
- def setter(self, value):
- if not isinstance(value, ty):
- raise TypeError(
- "Attempting to set the option %s to incompatible value: %r when "
- "it expects %r" % (name, value, ty))
- setattr(self, "_" + name, value)
+ experimental_map_and_batch_fusion = options_lib.create_option(
+ name="experimental_map_and_batch_fusion",
+ ty=bool,
+ docstring="Whether to fuse map and batch transformations.")
- return setter
+ experimental_map_and_filter_fusion = options_lib.create_option(
+ name="experimental_map_and_filter_fusion",
+ ty=bool,
+ docstring="Whether to fuse map and filter transformations.")
- vars()["_" + _name] = None
- vars()[_name] = property(
- _make_getter(_name), _make_setter(_name, _ty), None, _docstring)
+ experimental_map_fusion = options_lib.create_option(
+ name="experimental_map_and_filter_fusion",
+ ty=bool,
+ docstring="Whether to fuse map transformations.")
- def __init__(self):
- pass
+ experimental_map_parallelization = options_lib.create_option(
+ name="experimental_map_parallelization",
+ ty=bool,
+ docstring="Whether to parallelize stateless map transformations.")
- def __eq__(self, other):
- if isinstance(other, self.__class__):
- return self.__dict__ == other.__dict__
- else:
- return False
+ experimental_map_vectorization = options_lib.create_option(
+ name="experimental_map_vectorization",
+ ty=bool,
+ docstring="Whether to vectorize map transformations.")
- def __ne__(self, other):
- return not self.__eq__(other)
+ experimental_noop_elimination = options_lib.create_option(
+ name="experimental_noop_elimination",
+ ty=bool,
+ docstring="Whether to eliminate no-op transformations.")
+
+ experimental_numa_aware = options_lib.create_option(
+ name="experimental_numa_aware",
+ ty=bool,
+ docstring="Whether to use NUMA-aware operations.")
+
+ experimental_shuffle_and_repeat_fusion = options_lib.create_option(
+ name="experimental_shuffle_and_repeat_fusion",
+ ty=bool,
+ docstring="Whether to fuse shuffle and repeat transformations.")
+
+ experimental_stats = options_lib.create_option(
+ name="experimental_stats",
+ ty=stats_options.StatsOptions,
+ docstring="Associates the given statistics options with the dataset.")
+
+ experimental_threading = options_lib.create_option(
+ name="experimental_threading",
+ ty=threading_options.ThreadingOptions,
+ docstring="Associates the given threading options with the dataset.")
def _static_optimizations(self):
"""Produces the list of enabled static optimizations."""
@@ -1686,32 +1712,7 @@
New `tf.data.Options()` object which is the result of merging self with
the input `tf.data.Options`.
"""
- result = Options()
- for other in [self, options]:
- for name in [
- "experimental_autotune",
- "experimental_deterministic",
- "experimental_filter_fusion",
- "experimental_hoist_random_uniform",
- "experimental_map_and_batch_fusion",
- "experimental_map_and_filter_fusion",
- "experimental_map_fusion",
- "experimental_map_parallelization",
- "experimental_map_vectorization",
- "experimental_noop_elimination",
- "experimental_numa_aware",
- "experimental_shuffle_and_repeat_fusion",
- "experimental_stats",
- ]:
- this = getattr(result, name)
- that = getattr(other, name)
- if that is not None:
- if this is None:
- setattr(result, name, that)
- elif this != that:
- raise ValueError(
- "Cannot merge incompatible values of option: %s" % (name))
- return result
+ return options_lib.merge_options(self, options)
class DatasetSource(DatasetV2):
@@ -1868,57 +1869,6 @@
return (dtypes.int64, self._sparse_tensor.dtype, dtypes.int64)
-class _NestedDatasetComponent(object):
- """The structure of a `Dataset` nested in a component of another `Dataset`.
-
- A `StructuredFunctionWrapper` around a function that returns a `Dataset` as
- one of its components will have a `NestedDatasetComponent` in the
- corresponding position in the `output_classes`, `output_shapes`, and
- `output_types` properties.
-
- TODO(b/110122868): Add this class, or something equivalent, to the public API.
- We are considering revising the public API for accessing Dataset structure
- (`output_classes` etc.) based on experience with nested datasets and other
- custom component types.
- """
-
- def __init__(self,
- dataset=None,
- output_shapes=None,
- output_types=None,
- output_classes=None):
- if dataset is None:
- if (output_classes is None or output_shapes is None or
- output_types is None):
- raise ValueError(
- "Either `dataset`, or all of `output_classes`, "
- "`output_shapes`, and `output_types` must be specified.")
- self._output_classes = output_classes
- self._output_shapes = output_shapes
- self._output_types = output_types
- else:
- if not (output_classes is None and output_shapes is None and
- output_types is None):
- raise ValueError(
- "Either `dataset`, or all of `output_classes`, "
- "`output_shapes`, and `output_types` must be specified.")
- self._output_classes = dataset.output_classes
- self._output_shapes = dataset.output_shapes
- self._output_types = dataset.output_types
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
-
-
class _VariantDataset(DatasetV2):
"""A Dataset wrapper around a `tf.variant`-typed function argument."""
@@ -1935,15 +1885,73 @@
@property
def output_classes(self):
- return self._structure.output_classes
+ return self._structure._to_legacy_output_classes() # pylint: disable=protected-access
@property
def output_shapes(self):
- return self._structure.output_shapes
+ return self._structure._to_legacy_output_shapes() # pylint: disable=protected-access
@property
def output_types(self):
- return self._structure.output_types
+ return self._structure._to_legacy_output_types() # pylint: disable=protected-access
+
+
+class DatasetStructure(structure_lib.Structure):
+ """Represents a `Dataset` of structured values."""
+
+ def __init__(self, element_structure):
+ self._element_structure = element_structure
+
+ @property
+ def _flat_shapes(self):
+ return [tensor_shape.scalar()]
+
+ @property
+ def _flat_types(self):
+ return [dtypes.variant]
+
+ def is_compatible_with(self, other):
+ # pylint: disable=protected-access
+ return (isinstance(other, DatasetStructure) and
+ self._element_structure.is_compatible_with(
+ other._element_structure))
+
+ def _to_tensor_list(self, value):
+ return [value._as_variant_tensor()] # pylint: disable=protected-access
+
+ def _from_tensor_list(self, flat_value):
+ if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
+ not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "DatasetStructure corresponds to a single tf.variant scalar.")
+ return self._from_compatible_tensor_list(flat_value)
+
+ def _from_compatible_tensor_list(self, flat_value):
+ # pylint: disable=protected-access
+ return _VariantDataset(flat_value[0], self._element_structure)
+
+ @staticmethod
+ def from_value(value):
+ # TODO(b/110122868): We can simplify this when a `Dataset` object has a
+ # `Structure`-valued property.
+ element_structure = structure_lib.Structure._from_legacy_structure(
+ value.output_types, value.output_shapes, value.output_classes)
+ return DatasetStructure(element_structure)
+
+ def _to_legacy_output_types(self):
+ return self
+
+ def _to_legacy_output_shapes(self):
+ return self
+
+ def _to_legacy_output_classes(self):
+ return self
+
+
+# pylint: disable=protected-access
+structure_lib.Structure._register_custom_converter(DatasetV2,
+ DatasetStructure.from_value)
+# pylint: enable=protected-access
class StructuredFunctionWrapper(object):
@@ -2001,6 +2009,9 @@
self._input_types = dataset.output_types
self._input_classes = dataset.output_classes
+ self._input_structure = structure_lib.Structure._from_legacy_structure( # pylint: disable=protected-access
+ self._input_types, self._input_shapes, self._input_classes)
+
self._transformation_name = transformation_name
readable_transformation_name = transformation_name.replace(
".", "_")[:-2] if len(transformation_name) > 2 else ""
@@ -2008,35 +2019,18 @@
readable_transformation_name,
function_utils.get_func_name(func),
str(ops.uid())
-
])
if defun_kwargs is None:
defun_kwargs = {}
@function.Defun(
- *self._defun_args(), func_name=self._func_name, **defun_kwargs)
+ *self._input_structure._flat_types, func_name=self._func_name, # pylint: disable=protected-access
+ **defun_kwargs)
def tf_data_structured_function_wrapper(*args):
"""Wrapper for passing nested structures to and from tf.data functions."""
- flat_args = []
- for arg, arg_class, arg_shape, arg_type in zip(
- args,
- nest.flatten(self._input_classes),
- nest.flatten(self._input_shapes),
- nest.flatten(self._input_types)):
- # TODO(b/110122868): Add a registration mechanism for new component
- # types.
- if arg_class is sparse_tensor_lib.SparseTensor:
- arg = sparse.deserialize_sparse_tensors(
- arg, arg_type, arg_shape, arg_class)
- arg.indices.set_shape([None, arg_shape.ndims])
- arg.dense_shape.set_shape([arg_shape.ndims])
- elif isinstance(arg_class, _NestedDatasetComponent):
- arg = _VariantDataset(arg, arg_class)
- else:
- arg.set_shape(arg_shape)
- flat_args.append(arg)
- nested_args = nest.pack_sequence_as(self._input_classes, flat_args)
+ # pylint: disable=protected-access
+ nested_args = self._input_structure._from_compatible_tensor_list(args)
if not _should_unpack_args(nested_args):
nested_args = (nested_args,)
@@ -2054,50 +2048,14 @@
if isinstance(ret, list):
ret = tuple(ret)
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- flat_ret = []
- flat_classes = []
- flat_shapes = []
- flat_types = []
- for t in nest.flatten(ret):
- # TODO(b/110122868): Add a registration mechanism for new component
- # types.
- if sparse_tensor_lib.is_sparse(t):
- t = sparse_tensor_lib.SparseTensor.from_value(t)
- flat_ret.append(sparse.serialize_sparse_tensors(t))
- flat_classes.append(sparse_tensor_lib.SparseTensor)
- flat_shapes.append(t.get_shape())
- flat_types.append(t.dtype)
- elif isinstance(t, DatasetV2):
- flat_ret.append(t._as_variant_tensor()) # pylint: disable=protected-access
- component = _NestedDatasetComponent(t)
- flat_classes.append(component)
- flat_shapes.append(component)
- flat_types.append(component)
- if t.options() != Options():
- warnings.warn("Encountered a nested dataset with non-default "
- "options. These options will not be propagated to "
- "the outer dataset.")
- else:
- try:
- t = ops.convert_to_tensor(t)
- except (ValueError, TypeError):
- raise TypeError("Unsupported return value from function passed to "
- "%s: %s." % (transformation_name, t))
- flat_ret.append(t)
- flat_classes.append(ops.Tensor)
- flat_shapes.append(t.get_shape())
- flat_types.append(t.dtype)
-
- ret = nest.pack_sequence_as(ret, flat_ret)
- self._output_classes = nest.pack_sequence_as(ret, flat_classes)
- self._output_shapes = nest.pack_sequence_as(ret, flat_shapes)
- self._output_types = nest.pack_sequence_as(ret, flat_types)
+ try:
+ self._output_structure = structure_lib.Structure.from_value(ret)
+ except (ValueError, TypeError):
+ raise TypeError("Unsupported return value from function passed to "
+ "%s: %s." % (transformation_name, ret))
_warn_if_collections(transformation_name)
-
- return flat_ret
+ return self._output_structure._to_tensor_list(ret)
self._function = tf_data_structured_function_wrapper
if add_to_graph:
@@ -2108,32 +2066,21 @@
# in case (e.g.) we need to rerun the function.
self._function._create_definition_if_needed() # pylint: disable=protected-access
- def _defun_args(self):
- """Returns a flat list of `tf.DType` for the input element structure."""
- ret = []
- for input_type, input_class in zip(nest.flatten(self._input_types),
- nest.flatten(self._input_classes)):
- # TODO(b/110122868): Add a registration mechanism for new component types.
- if input_class is sparse_tensor_lib.SparseTensor:
- ret.append(dtypes.variant)
- elif isinstance(input_class, _NestedDatasetComponent):
- ret.append(dtypes.variant)
- else:
- assert isinstance(input_type, dtypes.DType)
- ret.append(input_type)
- return ret
+ @property
+ def output_structure(self):
+ return self._output_structure
@property
def output_classes(self):
- return self._output_classes
+ return self._output_structure._to_legacy_output_classes() # pylint: disable=protected-access
@property
def output_shapes(self):
- return self._output_shapes
+ return self._output_structure._to_legacy_output_shapes() # pylint: disable=protected-access
@property
def output_types(self):
- return self._output_types
+ return self._output_structure._to_legacy_output_types() # pylint: disable=protected-access
@property
def function(self):
@@ -2156,30 +2103,12 @@
A dictionary of keyword arguments that can be passed to many Dataset op
constructors.
"""
- output_classes = []
- output_shapes = []
- output_types = []
- for output_class, output_shape, output_type in zip(
- nest.flatten(dataset.output_classes), nest.flatten(dataset.output_shapes),
- nest.flatten(dataset.output_types)):
- if isinstance(output_class, _NestedDatasetComponent):
- output_classes.append(output_class.output_classes)
- output_shapes.append(output_shape.output_shapes)
- output_types.append(output_type.output_types)
- else:
- output_classes.append(output_class)
- output_shapes.append(output_shape)
- output_types.append(output_type)
-
- output_classes = nest.pack_sequence_as(dataset.output_classes, output_classes)
- output_shapes = nest.pack_sequence_as(dataset.output_shapes, output_shapes)
- output_types = nest.pack_sequence_as(dataset.output_types, output_types)
-
+ # pylint: disable=protected-access
+ structure = structure_lib.Structure._from_legacy_structure(
+ dataset.output_types, dataset.output_shapes, dataset.output_classes)
return {
- "output_shapes":
- nest.flatten(sparse.as_dense_shapes(output_shapes, output_classes)),
- "output_types":
- nest.flatten(sparse.as_dense_types(output_types, output_classes)),
+ "output_shapes": structure._flat_shapes,
+ "output_types": structure._flat_types,
}
@@ -2840,30 +2769,6 @@
return "Dataset.map()"
-class MatchingFilesDataset(DatasetSource):
- """A `Dataset` that list the files according to the input patterns."""
-
- def __init__(self, patterns):
- super(MatchingFilesDataset, self).__init__()
- self._patterns = ops.convert_to_tensor(
- patterns, dtype=dtypes.string, name="patterns")
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.matching_files_dataset(self._patterns)
-
- @property
- def output_classes(self):
- return ops.Tensor
-
- @property
- def output_shapes(self):
- return tensor_shape.scalar()
-
- @property
- def output_types(self):
- return dtypes.string
-
-
class ParallelMapDataset(MapDataset):
"""A `Dataset` that maps a function over elements in its input in parallel."""
@@ -2902,11 +2807,13 @@
wrapped_func = StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)
- if not isinstance(wrapped_func.output_classes, _NestedDatasetComponent):
+ if not isinstance(wrapped_func.output_structure, DatasetStructure):
raise TypeError("`map_func` must return a `Dataset` object.")
- self._output_classes = wrapped_func.output_classes.output_classes
- self._output_types = wrapped_func.output_types.output_types
- self._output_shapes = wrapped_func.output_shapes.output_shapes
+ # pylint: disable=protected-access
+ element_structure = wrapped_func.output_structure._element_structure
+ self._output_classes = element_structure._to_legacy_output_classes()
+ self._output_types = element_structure._to_legacy_output_types()
+ self._output_shapes = element_structure._to_legacy_output_shapes()
self._map_func = wrapped_func.function
def _as_variant_tensor(self):
@@ -3048,10 +2955,9 @@
self._output_classes = nest.pack_sequence_as(
input_dataset.output_classes,
[
- _NestedDatasetComponent( # pylint: disable=protected-access
- output_classes=output_class,
- output_shapes=output_shape,
- output_types=output_type)
+ DatasetStructure(
+ structure_lib.Structure._from_legacy_structure( # pylint: disable=protected-access
+ output_type, output_shape, output_class))
for output_class, output_shape, output_type in zip(
nest.flatten(input_dataset.output_classes),
nest.flatten(input_dataset.output_shapes),
@@ -3135,7 +3041,7 @@
class _SetStatsAggregatorDataset(UnaryUnchangedStructureDataset):
- """A `Dataset` that acts as an identity, and sets stats aggregator."""
+ """A `Dataset` that acts as an identity, and sets a stats aggregator."""
def __init__(self, input_dataset, aggregator, prefix, counter_prefix):
super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
@@ -3151,3 +3057,37 @@
self._prefix,
self._counter_prefix,
**flat_structure(self))
+
+
+class _MaxIntraOpParallelismDataset(UnaryUnchangedStructureDataset):
+ """A `Dataset` that acts as an identity, overriding intra-op parallelism."""
+
+ def __init__(self, input_dataset, max_intra_op_parallelism):
+ super(_MaxIntraOpParallelismDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._max_intra_op_parallelism = ops.convert_to_tensor(
+ max_intra_op_parallelism,
+ dtype=dtypes.int64,
+ name="max_intra_op_parallelism")
+
+ def _as_variant_tensor(self):
+ return ged_ops.experimental_max_intra_op_parallelism_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._max_intra_op_parallelism,
+ **flat_structure(self))
+
+
+class _PrivateThreadPoolDataset(UnaryUnchangedStructureDataset):
+ """A `Dataset` that acts as an identity, setting a private threadpool."""
+
+ def __init__(self, input_dataset, num_threads):
+ super(_PrivateThreadPoolDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._num_threads = ops.convert_to_tensor(
+ num_threads, dtype=dtypes.int64, name="num_threads")
+
+ def _as_variant_tensor(self):
+ return ged_ops.experimental_private_thread_pool_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._num_threads,
+ **flat_structure(self))
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 68b03ba..e2ca64c 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -68,7 +68,7 @@
return not bool(device_stack)
-@tf_export("data.Iterator")
+@tf_export(v1=["data.Iterator"])
class Iterator(checkpointable.CheckpointableBase):
"""Represents the state of iterating through a `Dataset`."""
diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py
index 91cf883..4113b7e 100644
--- a/tensorflow/python/data/ops/optional_ops.py
+++ b/tensorflow/python/data/ops/optional_ops.py
@@ -183,19 +183,13 @@
return OptionalStructure(value.value_structure)
def _to_legacy_output_types(self):
- raise NotImplementedError("The `output_types` property is not supported on "
- "structured objects containing an `Optional`. "
- "Use the corresponding `structure` property.")
+ return self
def _to_legacy_output_shapes(self):
- raise NotImplementedError("The `output_shapes` property is not supported on"
- " structured objects containing an `Optional`. "
- "Use the corresponding `structure` property.")
+ return self
def _to_legacy_output_classes(self):
- raise NotImplementedError("The `output_classes` property is not supported "
- "on structured objects containing an `Optional`. "
- "Use the corresponding `structure` property.")
+ return self
# pylint: disable=protected-access
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index 880e005..a93f3e4 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -180,7 +180,7 @@
def __init__(self, filenames, compression_type=None, buffer_size=None,
num_parallel_reads=None):
- """Creates a `TFRecordDataset` to read for one or more TFRecord files.
+ """Creates a `TFRecordDataset` to read one or more TFRecord files.
NOTE: The `num_parallel_reads` argument can be used to improve performance
when reading from a remote filesystem.
diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD
index 39082ce..f15ebc3 100644
--- a/tensorflow/python/data/util/BUILD
+++ b/tensorflow/python/data/util/BUILD
@@ -98,6 +98,23 @@
)
py_library(
+ name = "options",
+ srcs = ["options.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "options_test",
+ size = "small",
+ srcs = ["options_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":options",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_library(
name = "convert",
srcs = ["convert.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/data/util/convert_test.py b/tensorflow/python/data/util/convert_test.py
index 4a5b730..3058e2b 100644
--- a/tensorflow/python/data/util/convert_test.py
+++ b/tensorflow/python/data/util/convert_test.py
@@ -30,47 +30,52 @@
def testInteger(self):
resp = convert.optional_param_to_tensor("foo", 3)
- with self.cached_session() as sess:
- self.assertEqual(3, self.evaluate(resp))
+ self.assertEqual(3, self.evaluate(resp))
def testIntegerDefault(self):
resp = convert.optional_param_to_tensor("foo", None)
- with self.cached_session() as sess:
- self.assertEqual(0, self.evaluate(resp))
+ self.assertEqual(0, self.evaluate(resp))
def testStringDefault(self):
resp = convert.optional_param_to_tensor("bar", None, "default",
dtypes.string)
- with self.cached_session() as sess:
- self.assertEqual(compat.as_bytes("default"), self.evaluate(resp))
+ self.assertEqual(compat.as_bytes("default"), self.evaluate(resp))
def testString(self):
resp = convert.optional_param_to_tensor("bar", "value", "default",
dtypes.string)
- with self.cached_session() as sess:
- self.assertEqual(compat.as_bytes("value"), self.evaluate(resp))
+ self.assertEqual(compat.as_bytes("value"), self.evaluate(resp))
def testPartialShapeToTensorKnownDimension(self):
- with self.cached_session() as sess:
- self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
- tensor_shape.TensorShape([1]))))
- self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor((1,))))
- self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor([1])))
- self.assertAllEqual([1], sess.run(convert.partial_shape_to_tensor(
- constant_op.constant([1], dtype=dtypes.int64))))
+ self.assertAllEqual([1],
+ self.evaluate(
+ convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([1]))))
+ self.assertAllEqual([1], self.evaluate(
+ convert.partial_shape_to_tensor((1,))))
+ self.assertAllEqual([1], self.evaluate(
+ convert.partial_shape_to_tensor([1])))
+ self.assertAllEqual([1],
+ self.evaluate(
+ convert.partial_shape_to_tensor(
+ constant_op.constant([1], dtype=dtypes.int64))))
def testPartialShapeToTensorUnknownDimension(self):
- with self.cached_session() as sess:
- self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
- tensor_shape.TensorShape([None]))))
- self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
- (None,))))
- self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
- [None])))
- self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
- [-1])))
- self.assertAllEqual([-1], sess.run(convert.partial_shape_to_tensor(
- constant_op.constant([-1], dtype=dtypes.int64))))
+ self.assertAllEqual([-1],
+ self.evaluate(
+ convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([None]))))
+ self.assertAllEqual([-1],
+ self.evaluate(convert.partial_shape_to_tensor((None,))))
+ self.assertAllEqual([-1],
+ self.evaluate(convert.partial_shape_to_tensor([None])))
+ self.assertAllEqual([-1],
+ self.evaluate(convert.partial_shape_to_tensor([-1])))
+ self.assertAllEqual([-1],
+ self.evaluate(
+ convert.partial_shape_to_tensor(
+ constant_op.constant([-1],
+ dtype=dtypes.int64))))
with self.assertRaisesRegexp(
ValueError, r"The given shape .* must be a 1-D tensor of tf.int64 "
@@ -84,42 +89,63 @@
convert.partial_shape_to_tensor(constant_op.constant([1., 1.]))
def testPartialShapeToTensorMultipleDimensions(self):
- with self.cached_session() as sess:
- self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
- tensor_shape.TensorShape([3, 6]))))
- self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
- (3, 6))))
- self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
- [3, 6])))
- self.assertAllEqual([3, 6], sess.run(convert.partial_shape_to_tensor(
- constant_op.constant([3, 6], dtype=dtypes.int64))))
+ self.assertAllEqual([3, 6],
+ self.evaluate(
+ convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([3, 6]))))
+ self.assertAllEqual([3, 6],
+ self.evaluate(convert.partial_shape_to_tensor((3, 6))))
+ self.assertAllEqual([3, 6],
+ self.evaluate(convert.partial_shape_to_tensor([3, 6])))
+ self.assertAllEqual([3, 6],
+ self.evaluate(
+ convert.partial_shape_to_tensor(
+ constant_op.constant([3, 6],
+ dtype=dtypes.int64))))
- self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
- tensor_shape.TensorShape([3, None]))))
- self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
- (3, None))))
- self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
- [3, None])))
- self.assertAllEqual([3, -1], sess.run(convert.partial_shape_to_tensor(
- constant_op.constant([3, -1], dtype=dtypes.int64))))
+ self.assertAllEqual([3, -1],
+ self.evaluate(
+ convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([3, None]))))
+ self.assertAllEqual([3, -1],
+ self.evaluate(
+ convert.partial_shape_to_tensor((3, None))))
+ self.assertAllEqual([3, -1],
+ self.evaluate(
+ convert.partial_shape_to_tensor([3, None])))
+ self.assertAllEqual([3, -1],
+ self.evaluate(
+ convert.partial_shape_to_tensor(
+ constant_op.constant([3, -1],
+ dtype=dtypes.int64))))
- self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
- tensor_shape.TensorShape([None, None]))))
- self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
- (None, None))))
- self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
- [None, None])))
- self.assertAllEqual([-1, -1], sess.run(convert.partial_shape_to_tensor(
- constant_op.constant([-1, -1], dtype=dtypes.int64))))
+ self.assertAllEqual([-1, -1],
+ self.evaluate(
+ convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([None, None]))))
+ self.assertAllEqual([-1, -1],
+ self.evaluate(
+ convert.partial_shape_to_tensor((None, None))))
+ self.assertAllEqual([-1, -1],
+ self.evaluate(
+ convert.partial_shape_to_tensor([None, None])))
+ self.assertAllEqual([-1, -1],
+ self.evaluate(
+ convert.partial_shape_to_tensor(
+ constant_op.constant([-1, -1],
+ dtype=dtypes.int64))))
def testPartialShapeToTensorScalar(self):
- with self.cached_session() as sess:
- self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
- tensor_shape.TensorShape([]))))
- self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(())))
- self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor([])))
- self.assertAllEqual([], sess.run(convert.partial_shape_to_tensor(
- constant_op.constant([], dtype=dtypes.int64))))
+ self.assertAllEqual([],
+ self.evaluate(
+ convert.partial_shape_to_tensor(
+ tensor_shape.TensorShape([]))))
+ self.assertAllEqual([], self.evaluate(convert.partial_shape_to_tensor(())))
+ self.assertAllEqual([], self.evaluate(convert.partial_shape_to_tensor([])))
+ self.assertAllEqual([],
+ self.evaluate(
+ convert.partial_shape_to_tensor(
+ constant_op.constant([], dtype=dtypes.int64))))
if __name__ == "__main__":
diff --git a/tensorflow/python/data/util/options.py b/tensorflow/python/data/util/options.py
new file mode 100644
index 0000000..9badba8
--- /dev/null
+++ b/tensorflow/python/data/util/options.py
@@ -0,0 +1,131 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for tf.data options."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+def _internal_attr_name(name):
+ return "_" + name
+
+
+class OptionsBase(object):
+ """Base class for representing a set of tf.data options.
+
+ Attributes:
+ _options: Stores the option values.
+ """
+
+ def __init__(self):
+ self._options = {}
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return NotImplemented
+ for name in set(self._options) | set(other._options): # pylint: disable=protected-access
+ if getattr(self, name) != getattr(other, name):
+ return False
+ return True
+
+ def __ne__(self, other):
+ if isinstance(other, self.__class__):
+ return not self.__eq__(other)
+ else:
+ return NotImplemented
+
+
+def create_option(name, ty, docstring, default=None):
+ """Creates a type-checked property.
+
+ Args:
+ name: the name to use
+ ty: the type to use
+ docstring: the docstring to use
+ default: the default value to use
+
+ Returns:
+ A type-checked property.
+ """
+
+ def get_fn(self):
+ return self._options.get(name, default) # pylint: disable=protected-access
+
+ def set_fn(self, value):
+ if not isinstance(value, ty):
+ raise TypeError("Property \"%s\" must be of type %s, got: %r (type: %r)" %
+ (name, ty, value, type(value)))
+ self._options[name] = value # pylint: disable=protected-access
+
+ return property(get_fn, set_fn, None, docstring)
+
+
+def merge_options(*options_list):
+ """Merges the given options, returning the result as a new options object.
+
+ The input arguments are expected to have a matching type that derives from
+ `OptionsBase` (and thus each represent a set of options). The method outputs
+ an object of the same type created by merging the sets of options represented
+ by the input arguments.
+
+ The sets of options can be merged as long as there does not exist an option
+ with different non-default values.
+
+ If an option is an instance of `OptionsBase` itself, then this method is
+ applied recursively to the set of options represented by this option.
+
+ Args:
+ *options_list: options to merge
+
+ Raises:
+ TypeError: if the input arguments are incompatible or not derived from
+ `OptionsBase`
+ ValueError: if the given options cannot be merged
+
+ Returns:
+ A new options object which is the result of merging the given options.
+ """
+ if len(options_list) < 1:
+ raise ValueError("At least one options should be provided")
+ result_type = type(options_list[0])
+
+ for options in options_list:
+ if not isinstance(options, result_type):
+ raise TypeError("Incompatible options type: %r vs %r" % (type(options),
+ result_type))
+
+ if not isinstance(options_list[0], OptionsBase):
+ raise TypeError("The inputs should inherit from `OptionsBase`")
+
+ default_options = result_type()
+ result = result_type()
+ for options in options_list:
+ # Iterate over all set options and merge the into the result.
+ for name in options._options: # pylint: disable=protected-access
+ this = getattr(result, name)
+ that = getattr(options, name)
+ default = getattr(default_options, name)
+ if that == default:
+ continue
+ elif this == default:
+ setattr(result, name, that)
+ elif isinstance(this, OptionsBase):
+ setattr(result, name, merge_options(this, that))
+ elif this != that:
+ raise ValueError(
+ "Cannot merge incompatible values (%r and %r) of option: %s" %
+ (this, that, name))
+ return result
diff --git a/tensorflow/python/data/util/options_test.py b/tensorflow/python/data/util/options_test.py
new file mode 100644
index 0000000..c516983
--- /dev/null
+++ b/tensorflow/python/data/util/options_test.py
@@ -0,0 +1,96 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for dataset options utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.util import options
+from tensorflow.python.platform import test
+
+
+class _TestOptions(options.OptionsBase):
+ x = options.create_option(
+ name="x", ty=int, docstring="the answer to everything", default=42)
+ y = options.create_option(
+ name="y", ty=float, docstring="a tasty pie", default=3.14)
+
+
+class _NestedTestOptions(options.OptionsBase):
+ opts = options.create_option(
+ name="opts", ty=_TestOptions, docstring="nested options")
+
+
+class OptionsTest(test.TestCase):
+
+ def testDocumentation(self):
+ self.assertEqual(_TestOptions.x.__doc__, "the answer to everything")
+ self.assertEqual(_TestOptions.y.__doc__, "a tasty pie")
+
+ def testCreateOption(self):
+ opts = _TestOptions()
+ self.assertEqual(opts.x, 42)
+ self.assertEqual(opts.y, 3.14)
+ self.assertIsInstance(opts.x, int)
+ self.assertIsInstance(opts.y, float)
+ opts.x = 0
+ self.assertEqual(opts.x, 0)
+ with self.assertRaises(TypeError):
+ opts.x = 3.14
+ opts.y = 0.0
+ self.assertEqual(opts.y, 0.0)
+ with self.assertRaises(TypeError):
+ opts.y = 42
+
+ def testMergeOptions(self):
+ options1, options2 = _TestOptions(), _TestOptions()
+ with self.assertRaises(ValueError):
+ options.merge_options()
+ merged_options = options.merge_options(options1, options2)
+ self.assertEqual(merged_options.x, 42)
+ self.assertEqual(merged_options.y, 3.14)
+ options1.x = 0
+ options2.y = 0.0
+ merged_options = options.merge_options(options1, options2)
+ self.assertEqual(merged_options.x, 0)
+ self.assertEqual(merged_options.y, 0.0)
+
+ def testMergeNestedOptions(self):
+ options1, options2 = _NestedTestOptions(), _NestedTestOptions()
+ merged_options = options.merge_options(options1, options2)
+ self.assertEqual(merged_options.opts, None)
+ options1.opts = _TestOptions()
+ merged_options = options.merge_options(options1, options2)
+ self.assertEqual(merged_options.opts, _TestOptions())
+ options2.opts = _TestOptions()
+ merged_options = options.merge_options(options1, options2)
+ self.assertEqual(merged_options.opts, _TestOptions())
+ options1.opts.x = 0
+ options2.opts.y = 0.0
+ merged_options = options.merge_options(options1, options2)
+ self.assertEqual(merged_options.opts.x, 0)
+ self.assertEqual(merged_options.opts.y, 0.0)
+
+ def testMergeOptionsInvalid(self):
+ with self.assertRaises(TypeError):
+ options.merge_options(0)
+ options1, options2 = _TestOptions(), _NestedTestOptions()
+ with self.assertRaises(TypeError):
+ options.merge_options(options1, options2)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/util/sparse.py b/tensorflow/python/data/util/sparse.py
index 5e6d224..f2e22fe 100644
--- a/tensorflow/python/data/util/sparse.py
+++ b/tensorflow/python/data/util/sparse.py
@@ -34,7 +34,7 @@
Returns:
`True` if `classes` contains a sparse tensor type and `False` otherwise.
"""
- return any([c is sparse_tensor.SparseTensor for c in nest.flatten(classes)])
+ return any(c is sparse_tensor.SparseTensor for c in nest.flatten(classes))
def as_dense_shapes(shapes, classes):
diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py
index 9a31182..3cf67b0 100644
--- a/tensorflow/python/data/util/structure.py
+++ b/tensorflow/python/data/util/structure.py
@@ -208,14 +208,16 @@
flat_ret = []
for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
flat_classes):
- if issubclass(flat_class, sparse_tensor_lib.SparseTensor):
+ if isinstance(flat_class, Structure):
+ flat_ret.append(flat_class)
+ elif issubclass(flat_class, sparse_tensor_lib.SparseTensor):
flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
elif issubclass(flat_class, ops.Tensor):
flat_ret.append(TensorStructure(flat_type, flat_shape))
else:
# NOTE(mrry): Since legacy structures produced by iterators only
- # comprise Tensors, SparseTensors, and nests, we do not need to support
- # all structure types here.
+ # comprise Tensors, SparseTensors, and nests, we do not need to
+ # support all structure types here.
raise TypeError(
"Could not build a structure for output class %r" % flat_type)
@@ -381,6 +383,13 @@
return self._from_compatible_tensor_list(flat_value)
def _from_compatible_tensor_list(self, flat_value):
+ # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
+ # op here and return that, instead of mutating the input's shape using
+ # `Tensor.set_shape()`. However, that would add extra ops on the arguments
+ # of each `tf.data` function, which could impact performance. When this
+ # bug is resolved, we should be able to add the `ensure_shape()` ops and
+ # optimize them away using contextual shape information.
+ flat_value[0].set_shape(self._shape)
return flat_value[0]
@staticmethod
@@ -406,7 +415,11 @@
@property
def _flat_shapes(self):
- return [tensor_shape.vector(3)]
+ # NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
+ # but a `SparseTensorStructure` can also represent a batch of boxed
+ # `SparseTensor` objects with shape `(?, 3)` (and batches of batches, etc.),
+ # so the flat shape must be unknown.
+ return [tensor_shape.unknown_shape(None)]
@property
def _flat_types(self):
@@ -428,8 +441,11 @@
return self._from_compatible_tensor_list(flat_value)
def _from_compatible_tensor_list(self, flat_value):
- return sparse_ops.deserialize_sparse(
+ ret = sparse_ops.deserialize_sparse(
flat_value[0], dtype=self._dtype, rank=self._dense_shape.ndims)
+ ret.indices.set_shape([None, self._dense_shape.ndims])
+ ret.dense_shape.set_shape([self._dense_shape.ndims])
+ return ret
@staticmethod
def from_value(value):
diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py
index 630a0c9..65a41a5 100644
--- a/tensorflow/python/data/util/structure_test.py
+++ b/tensorflow/python/data/util/structure_test.py
@@ -44,7 +44,7 @@
[dtypes.float32], [[]]),
(lambda: sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
- structure.SparseTensorStructure, [dtypes.variant], [[3]]),
+ structure.SparseTensorStructure, [dtypes.variant], [None]),
(lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
(lambda: {
@@ -58,14 +58,17 @@
sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
}, structure.NestedStructure,
- [dtypes.float32, dtypes.variant, dtypes.variant], [[], [3], [3]]))
+ [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None]))
def testFlatStructure(self, value_fn, expected_structure, expected_types,
expected_shapes):
value = value_fn()
s = structure.Structure.from_value(value)
self.assertIsInstance(s, expected_structure)
self.assertEqual(expected_types, s._flat_types)
- self.assertEqual(expected_shapes, s._flat_shapes)
+ for expected, actual in zip(expected_shapes, s._flat_shapes):
+ self.assertTrue(actual.is_compatible_with(expected))
+ self.assertTrue(
+ tensor_shape.as_shape(expected).is_compatible_with(actual))
@parameterized.parameters(
(lambda: constant_op.constant(37.0), lambda: [
diff --git a/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py b/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
index b78c3d1..74498c8 100644
--- a/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
@@ -131,8 +131,8 @@
with session.Session(
config=self.session_config, graph=graph,
target=self.server_target) as sess:
- self.evaluate(self.a.initializer)
- self.evaluate(self.b.initializer)
+ sess.run(self.a.initializer)
+ sess.run(self.b.initializer)
run_options = config_pb2.RunOptions()
debug_utils.watch_graph(
@@ -198,8 +198,8 @@
with session.Session(
config=self.session_config, graph=graph,
target=self.server_target) as sess:
- self.evaluate(self.a.initializer)
- self.evaluate(self.b.initializer)
+ sess.run(self.a.initializer)
+ sess.run(self.b.initializer)
def watch_fn(feeds, fetch_keys):
del feeds, fetch_keys
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 83c3901..5afbcec 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -50,6 +50,7 @@
srcs_version = "PY2AND3",
deps = [
":cross_device_utils",
+ ":device_util",
":reduce_util",
":values",
"//tensorflow/python:array_ops",
@@ -58,8 +59,6 @@
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
"//tensorflow/python/eager:context",
"@six_archive//:six",
],
@@ -84,6 +83,67 @@
)
py_library(
+ name = "device_util",
+ srcs = ["device_util.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:device",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+cuda_py_test(
+ name = "device_util_test",
+ srcs = ["device_util_test.py"],
+ additional_deps = [
+ ":device_util",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ ],
+)
+
+py_library(
+ name = "distribute_lib",
+ srcs = [
+ "distribute_lib.py",
+ "distribution_strategy_context.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":device_util",
+ ":reduce_util",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/data",
+ "//tensorflow/python/ops/losses",
+ "//tensorflow/tools/docs:doc_controls",
+ ],
+)
+
+py_test(
+ name = "distribute_lib_test",
+ size = "small",
+ srcs = ["distribute_lib_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":distribute_lib",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:variable_scope",
+ ],
+)
+
+py_library(
name = "distribute_config",
srcs = [
"distribute_config.py",
@@ -140,6 +200,35 @@
)
py_library(
+ name = "mirrored_strategy",
+ srcs = ["mirrored_strategy.py"],
+ deps = [
+ ":cross_device_ops",
+ ":device_util",
+ ":distribute_lib",
+ ":multi_worker_util",
+ ":reduce_util",
+ ":shared_variable_creator",
+ ":values",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:device",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:pywrap_tensorflow",
+ "//tensorflow/python:tensor_util",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:tape",
+ ],
+)
+
+py_library(
name = "multi_worker_util",
srcs = [
"multi_worker_util.py",
@@ -166,12 +255,12 @@
additional_deps = [
":input_ops",
"//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:readers",
"//tensorflow/python:errors",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:io_ops",
- "//tensorflow/python/data/ops:readers",
"//tensorflow/python:util",
],
tags = [
@@ -242,11 +331,11 @@
name = "values",
srcs = ["values.py"],
deps = [
+ ":device_util",
+ ":distribute_lib",
":input_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
- "//tensorflow/python:device_util",
- "//tensorflow/python:distribute",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index f55385e..a88ed62 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -23,6 +23,7 @@
from tensorflow.python.client import device_lib
from tensorflow.python.distribute import cross_device_utils
+from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
@@ -31,7 +32,6 @@
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import device_util
def check_destinations(destinations):
@@ -103,10 +103,10 @@
# pylint: disable=g-missing-docstring
if not value_destination_pairs: return False
if not isinstance(value_destination_pairs, (list, tuple)): return False
- if not all([isinstance(pair, tuple) for pair in value_destination_pairs]):
+ if not all(isinstance(pair, tuple) for pair in value_destination_pairs):
return False
- if not all([isinstance(v[0], value_lib.PerReplica)
- for v in value_destination_pairs]):
+ if not all(isinstance(v[0], value_lib.PerReplica)
+ for v in value_destination_pairs):
return False
return True
@@ -132,10 +132,10 @@
def _all_devices_match(value_destination_pairs):
- if not all([_devices_match(v, d) for v, d in value_destination_pairs]):
+ if not all(_devices_match(v, d) for v, d in value_destination_pairs):
return False
- if not all([_devices_match(v, value_destination_pairs[0][0])
- for v, _ in value_destination_pairs[1:]]):
+ if not all(_devices_match(v, value_destination_pairs[0][0])
+ for v, _ in value_destination_pairs[1:]):
return False
return True
@@ -401,7 +401,7 @@
# all gradient shapes are defined, we use another method to get the
# total size.
# TODO(yuefengz): move this logic to array_ops.size.
- if all([g.shape.is_fully_defined() for g, _ in device_grads_and_vars]):
+ if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars):
total_grad_size = sum(
[g.shape.num_elements() for g, _ in device_grads_and_vars])
else:
@@ -916,15 +916,15 @@
def choose_the_best(devices, session_config=None):
- """Find the best subclass of CrossDeviceOps given a tensorflow session.
+ """Find the best subclass of CrossDeviceOps given a session config.
Args:
- devices: a list of devices passed for distribute strategy.
- session_config: a tensorflow session config or None. If None, it will make
- deciesion based on all local devices.
+ devices: a list of devices passed to `tf.distribute.Strategy`.
+ session_config: a `tf.ConfigProto` or `None`. If `None`, it will make
+ decision based on all local devices.
Returns:
- a subclass of CrossDeviceOps.
+ A subclass of `CrossDeviceOps`.
"""
requested_devices = set([device_util.canonicalize(d) for d in devices])
machine_devices = device_lib.list_local_devices(session_config=session_config)
@@ -937,13 +937,13 @@
"Device is available but not used by distribute strategy: %s", d.name)
if len(using_devices) != len(requested_devices):
- logging.warning("Not all devices in distribute strategy are visible by "
- "TensorFlow sessions.")
+ logging.warning("Not all devices in `tf.distribute.Strategy` are visible "
+ "to TensorFlow.")
return ReductionToOneDeviceCrossDeviceOps()
- if any([d.device_type.lower() != "gpu" for d in using_devices]):
- logging.warning("Not all devices in DistributionStrategy are visible to "
- "TensorFlow session.")
+ if any(d.device_type.lower() != "gpu" for d in using_devices):
+ logging.warning("Not all devices in `tf.distribute.Strategy` are visible "
+ "to TensorFlow.")
return ReductionToOneDeviceCrossDeviceOps()
device_links = [[] for _ in range(len(using_devices))]
diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py
index 7903992..0faadd7 100644
--- a/tensorflow/python/distribute/cross_device_utils.py
+++ b/tensorflow/python/distribute/cross_device_utils.py
@@ -420,7 +420,7 @@
Returns:
list of reduced tensors
"""
- alg_contains_shuffle = any([n in alg for n in ['pscpu', 'psgpu']])
+ alg_contains_shuffle = any(n in alg for n in ['pscpu', 'psgpu'])
is_hierarchical = '/' in alg
if 'pscpu' in alg:
aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes]
diff --git a/tensorflow/python/training/device_util.py b/tensorflow/python/distribute/device_util.py
similarity index 100%
rename from tensorflow/python/training/device_util.py
rename to tensorflow/python/distribute/device_util.py
diff --git a/tensorflow/python/training/device_util_test.py b/tensorflow/python/distribute/device_util_test.py
similarity index 98%
rename from tensorflow/python/training/device_util_test.py
rename to tensorflow/python/distribute/device_util_test.py
index cdbb082..baecd43 100644
--- a/tensorflow/python/training/device_util_test.py
+++ b/tensorflow/python/distribute/device_util_test.py
@@ -18,10 +18,10 @@
from __future__ import division
from __future__ import print_function
+from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
-from tensorflow.python.training import device_util
class DeviceUtilTest(test.TestCase):
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index 07d291e..c0f9b8a 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -245,7 +245,7 @@
else:
session_config = self._session_config
- if not self._strategy or self._strategy.should_init:
+ if not self._strategy or self._strategy.extended.experimental_should_init:
logging.info("Creating chief session creator with config: %r", config)
return monitored_session.ChiefSessionCreator(
scaffold,
@@ -306,19 +306,19 @@
return self._num_workers
@property
- def should_init(self):
+ def experimental_should_init(self):
"""Whether to run init ops."""
- return self._strategy.should_init
+ return self._strategy.extended.experimental_should_init
@property
def should_checkpoint(self):
"""Whether to save checkpoint."""
- return self._strategy.should_checkpoint
+ return self._strategy.extended.should_checkpoint
@property
def should_save_summary(self):
"""Whether to save summaries."""
- return self._strategy.should_save_summary
+ return self._strategy.extended.should_save_summary
def _run_single_worker(worker_fn,
@@ -632,10 +632,10 @@
The `strategy` object is expected to be a DistributionStrategy object which
has implemented methods needed by distributed coordinator such as
`configure(session_config, cluster_spec, task_type, task_id)` which configures
- the strategy object for a specific task and `should_init` property which
- instructs the distribute coordinator whether to run init ops for a task. The
- distribute coordinator will make a copy of the `strategy` object, call its
- `configure` method and pass it to `worker_fn` as an argument.
+ the strategy object for a specific task and `experimental_should_init`
+ property which instructs the distribute coordinator whether to run init ops
+ for a task. The distribute coordinator will make a copy of the `strategy`
+ object, call its `configure` method and pass it to `worker_fn` as an argument.
The `worker_fn` defines the training logic and is called under a its own
worker context which can be accessed to via `get_current_worker_context`. A
@@ -758,7 +758,7 @@
# The client must know the cluster but servers in the cluster don't have to
# know the client.
if task_type in [_TaskType.CLIENT, None]:
- if strategy.between_graph:
+ if strategy.extended.experimental_between_graph:
return _run_between_graph_client(worker_fn, strategy, eval_fn,
eval_strategy, cluster_spec,
session_config, rpc_layer)
@@ -804,7 +804,7 @@
environment=environment)
if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
- if strategy.between_graph:
+ if strategy.extended.experimental_between_graph:
# All jobs run `worker_fn` if between-graph.
_run_single_worker(worker_fn, strategy, cluster_spec, task_type,
task_id, session_config, rpc_layer)
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index 0c1ee8c..f2cb950 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -79,6 +79,19 @@
return target
+class MockExtended(object):
+
+ def __init__(self,
+ between_graph=False,
+ should_init=None,
+ should_checkpoint=None,
+ should_save_summary=None):
+ self.experimental_between_graph = between_graph
+ self.experimental_should_init = should_init
+ self.should_checkpoint = should_checkpoint
+ self.should_save_summary = should_save_summary
+
+
class MockStrategy(object):
def __init__(self,
@@ -86,39 +99,33 @@
should_init=None,
should_checkpoint=None,
should_save_summary=None):
- self._between_graph = between_graph
- self._should_init = should_init
- self._should_checkpoint = should_checkpoint
- self._should_save_summary = should_save_summary
-
- @property
- def between_graph(self):
- return self._between_graph
+ self.extended = MockExtended(between_graph, should_init, should_checkpoint,
+ should_save_summary)
def configure(self,
session_config=None,
cluster_spec=None,
task_type=None,
task_id=None):
- if self._should_init is None:
+ if self.extended.experimental_should_init is None:
if task_id == 0:
- self._should_init = True
+ self.extended.experimental_should_init = True
else:
- self._should_init = False
- if self._should_checkpoint is None:
+ self.extended.experimental_should_init = False
+ if self.extended.should_checkpoint is None:
if task_id == 0:
- self._should_checkpoint = True
+ self.extended.should_checkpoint = True
else:
- self._should_checkpoint = False
- if self._should_save_summary is None:
+ self.extended.should_checkpoint = False
+ if self.extended.should_save_summary is None:
if task_id == 0:
- self._should_save_summary = True
+ self.extended.should_save_summary = True
else:
- self._should_save_summary = False
+ self.extended.should_save_summary = False
if session_config:
if (cluster_spec and task_type and task_id is not None and
- self._between_graph):
+ self.extended.experimental_between_graph):
session_config.intra_op_parallelism_threads += 1
if task_type in ["chief", "worker"]:
session_config.device_filters.extend(
@@ -127,18 +134,6 @@
session_config.inter_op_parallelism_threads += 1
session_config.device_filters.append("/job:somejob")
- @property
- def should_init(self):
- return self._should_init
-
- @property
- def should_checkpoint(self):
- return self._should_checkpoint
-
- @property
- def should_save_summary(self):
- return self._should_save_summary
-
class MockServer(object):
@@ -373,9 +368,12 @@
context = distribute_coordinator_context.get_current_worker_context()
self.assertTrue(context is not None)
- self.assertEqual(context._strategy.should_init, strategy.should_init)
- self.assertEqual(context.should_checkpoint, strategy.should_checkpoint)
- self.assertEqual(context.should_save_summary, strategy.should_save_summary)
+ self.assertEqual(context._strategy.extended.experimental_should_init,
+ strategy.extended.experimental_should_init)
+ self.assertEqual(context.should_checkpoint,
+ strategy.extended.should_checkpoint)
+ self.assertEqual(context.should_save_summary,
+ strategy.extended.should_save_summary)
task_type = str(context.task_type)
task_id = context.task_id or 0
@@ -385,7 +383,8 @@
while len(self._strategy_property[task_type]) <= task_id:
self._strategy_property[task_type].append(None)
self._strategy_property[task_type][task_id] = (
- context._strategy.should_init, context.should_checkpoint,
+ context._strategy.extended.experimental_should_init,
+ context.should_checkpoint,
context.should_save_summary)
def _run_mock_std_server(self,
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
new file mode 100644
index 0000000..a1f03ea
--- /dev/null
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -0,0 +1,1665 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Library for running a computation across multiple devices."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import threading
+import weakref
+import enum
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.distribute import reduce_util
+from tensorflow.python.eager import context as eager_context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses_impl
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import nest
+from tensorflow.python.util.tf_export import tf_export
+from tensorflow.tools.docs import doc_controls
+
+
+# ------------------------------------------------------------------------------
+# Context tracking whether in a strategy.update() or .update_non_slot() call.
+
+
+_update_device = threading.local()
+
+
+def get_update_device():
+ """Get the current device if in a `tf.distribute.Strategy.update()` call."""
+ try:
+ return _update_device.current
+ except AttributeError:
+ return None
+
+
+class UpdateContext(object):
+ """Context manager when you are in `update()` or `update_non_slot()`."""
+
+ def __init__(self, device):
+ self._device = device
+ self._old_device = None
+
+ def __enter__(self):
+ self._old_device = get_update_device()
+ _update_device.current = self._device
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ del exception_type, exception_value, traceback
+ _update_device.current = self._old_device
+
+
+# ------------------------------------------------------------------------------
+# Public utility functions.
+
+
+@tf_export("distribute.get_loss_reduction")
+def get_loss_reduction():
+ """`tf.distribute.ReduceOp` corresponding to the last loss reduction."""
+ loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access
+ if loss_reduction == losses_impl.Reduction.SUM:
+ return reduce_util.ReduceOp.SUM
+ return reduce_util.ReduceOp.MEAN
+
+
+# ------------------------------------------------------------------------------
+# Internal API for validating the current thread mode
+
+
+def _require_cross_replica_context_extended(extended):
+ """Verify in cross-replica context."""
+ context = _get_per_thread_mode()
+ cross_replica = context.cross_replica_context
+ if cross_replica is not None and cross_replica.extended is extended:
+ return
+ strategy = extended._container_strategy() # pylint: disable=protected-access
+ # We have an error to report, figure out the right message.
+ if context.distribution_strategy is not strategy:
+ _wrong_strategy_scope(strategy, context)
+ assert cross_replica is None
+ raise RuntimeError("Method requires being in cross-replica context, use "
+ "get_replica_context().merge_call()")
+
+
+def _wrong_strategy_scope(strategy, context):
+ # Figure out the right error message.
+ if not distribution_strategy_context.has_distribution_strategy():
+ raise RuntimeError(
+ 'Need to be inside "with strategy.scope()" for %s' %
+ (strategy,))
+ else:
+ raise RuntimeError(
+ "Mixing different tf.distribute.Strategy objects: %s is not %s" %
+ (context.distribution_strategy, strategy))
+
+
+def require_replica_context(replica_ctx):
+ """Verify in `replica_ctx` replica context."""
+ context = _get_per_thread_mode()
+ if context.replica_context is replica_ctx: return
+ # We have an error to report, figure out the right message.
+ if context.replica_context is None:
+ raise RuntimeError("Need to be inside `call_for_each_replica()`")
+ if context.distribution_strategy is replica_ctx.distribution_strategy:
+ # Two different ReplicaContexts with the same tf.distribute.Strategy.
+ raise RuntimeError("Mismatching ReplicaContext.")
+ raise RuntimeError(
+ "Mismatching tf.distribute.Strategy objects: %s is not %s." %
+ (context.distribution_strategy, replica_ctx.distribution_strategy))
+
+
+def _require_distribution_strategy_scope_strategy(strategy):
+ """Verify in a `strategy.scope()` in this thread."""
+ context = _get_per_thread_mode()
+ if context.distribution_strategy is strategy: return
+ _wrong_strategy_scope(strategy, context)
+
+
+def _require_distribution_strategy_scope_extended(extended):
+ """Verify in a `distribution_strategy.scope()` in this thread."""
+ context = _get_per_thread_mode()
+ if context.distribution_strategy.extended is extended: return
+ # Report error.
+ strategy = extended._container_strategy() # pylint: disable=protected-access
+ _wrong_strategy_scope(strategy, context)
+
+
+# ------------------------------------------------------------------------------
+# Internal context managers used to implement the DistributionStrategy
+# base class
+
+
+class _CurrentDistributionContext(object):
+ """Context manager setting the current `tf.distribute.Strategy`.
+
+ Also: overrides the variable creator and optionally the current device.
+ """
+
+ def __init__(self,
+ strategy,
+ var_creator_scope,
+ var_scope=None,
+ default_device=None):
+ self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access
+ strategy)
+ self._var_creator_scope = var_creator_scope
+ self._var_scope = var_scope
+ if default_device:
+ self._device_scope = ops.device(default_device)
+ else:
+ self._device_scope = None
+
+ def __enter__(self):
+ _push_per_thread_mode(self._context)
+ if self._var_scope:
+ self._var_scope.__enter__()
+ self._var_creator_scope.__enter__()
+ if self._device_scope:
+ self._device_scope.__enter__()
+ return self._context.distribution_strategy
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ if self._device_scope:
+ self._device_scope.__exit__(exception_type, exception_value, traceback)
+ self._var_creator_scope.__exit__(exception_type, exception_value, traceback)
+ if self._var_scope:
+ self._var_scope.__exit__(exception_type, exception_value, traceback)
+ _pop_per_thread_mode()
+
+
+class _SameScopeAgainContext(object):
+ """Trivial context manager when you are already in `scope()`."""
+
+ def __init__(self, strategy):
+ self._distribution_strategy = strategy
+
+ def __enter__(self):
+ return self._distribution_strategy
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ del exception_type, exception_value, traceback
+
+
+# TODO(yuefengz): add more replication modes.
+@tf_export("distribute.InputReplicationMode")
+class InputReplicationMode(enum.Enum):
+ """Replication mode for input function."""
+
+ # The input function will be called on each worker independently, creating as
+ # many input pipelines as number of workers. Replicas will dequeue from the
+ # local Dataset on their worker. Distribution Strategy doesn't manage any
+ # state sharing between such separate input pipelines.
+ PER_WORKER = "PER_WORKER"
+
+
+@tf_export("distribute.InputContext")
+class InputContext(object):
+ """A class wrapping information needed by an input function.
+
+ This is a context class that is passed to the user's input fn and contains
+ information about the compute replicas and input pipelines. The number of
+ compute replicas (in sync training) helps compute per input pipeline batch
+ size from the desired global batch size. Input pipeline information can be
+ used to return a different subset of the input in each input pipeline (for
+ e.g. shard the input pipeline, use a different input source etc).
+ """
+
+ def __init__(self,
+ num_input_pipelines=1,
+ input_pipeline_id=0,
+ num_replicas_in_sync=1):
+ """Initializes an InputContext object.
+
+ Args:
+ num_input_pipelines: the number of input pipelines in a cluster.
+ input_pipeline_id: the current input pipeline id, should be an int in
+ [0,`num_input_pipelines`).
+ num_replicas_in_sync: the number of replicas that are in sync.
+ """
+ self._num_input_pipelines = num_input_pipelines
+ self._input_pipeline_id = input_pipeline_id
+ self._num_replicas_in_sync = num_replicas_in_sync
+
+ @property
+ def num_replicas_in_sync(self):
+ """Returns the number of compute replicas in sync."""
+ return self._num_replicas_in_sync
+
+ @property
+ def input_pipeline_id(self):
+ """Returns the input pipeline ID."""
+ return self._input_pipeline_id
+
+ @property
+ def num_input_pipelines(self):
+ """Returns the number of input pipelines."""
+ return self._num_input_pipelines
+
+ def get_per_replica_batch_size(self, global_batch_size):
+ """Returns the per-replica batch size.
+
+ Args:
+ global_batch_size: the global batch size which should be divisible by
+ `num_replicas_in_sync`.
+
+ Returns:
+ the per-replica batch size.
+
+ Raises:
+ ValueError: if `global_batch_size` not divisible by
+ `num_replicas_in_sync`.
+ """
+ if global_batch_size % self._num_replicas_in_sync != 0:
+ raise ValueError("The `global_batch_size` %r is not divisible by "
+ "`num_replicas_in_sync` %r " %
+ (global_batch_size, self._num_replicas_in_sync))
+ return global_batch_size // self._num_replicas_in_sync
+
+
+# ------------------------------------------------------------------------------
+# Base classes for all distribution strategies.
+
+
+@tf_export("distribute.Strategy")
+class DistributionStrategy(object):
+ """A list of devices with a state & compute distribution policy.
+
+ See [tensorflow/contrib/distribute/README.md](
+ https://www.tensorflow.org/code/tensorflow/contrib/distribute/README.md)
+ for overview and examples.
+ """
+
+ # TODO(josh11b): Raise an exception if variable partitioning requested before
+ # we add support.
+ # TODO(josh11b): Also `parameter_device_index` property?
+ # TODO(josh11b): `map()`
+ # TODO(josh11b): ClusterSpec/ClusterResolver
+ # TODO(josh11b): Partitioned computations, state; sharding
+ # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling
+ # TODO(josh11b): List of replicas with their worker and parameter devices
+ # (where the parameter devices may overlap in the ps case).
+
+ def __init__(self, extended):
+ self._extended = extended
+
+ @property
+ def extended(self):
+ """`tf.distribute.StrategyExtended` with additional methods."""
+ return self._extended
+
+ def scope(self):
+ """Returns a context manager selecting this Strategy as current.
+
+ Inside a `with strategy.scope():` code block, this thread
+ will use a variable creator set by `strategy`, and will
+ enter its "cross-replica context".
+
+ Returns:
+ A context manager.
+ """
+ return self._extended._scope(self) # pylint: disable=protected-access
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def read_var(self, v):
+ """DEPRECATED: use extended.read_var() instead."""
+ return self._extended.read_var(v)
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def colocate_vars_with(self, colocate_with_variable):
+ """DEPRECATED: use extended.colocate_vars_with() instead."""
+ return self._extended.colocate_vars_with(colocate_with_variable)
+
+ @doc_controls.do_not_generate_docs # DEPRECATED
+ def distribute_dataset(self, dataset_fn):
+ """Return a `dataset` split across all replicas. DEPRECATED.
+
+ DEPRECATED: Please use `make_dataset_iterator` or
+ `make_input_fn_iterator` instead.
+
+ Suitable for providing input to `extended.call_for_each_replica()` by
+ creating an iterator:
+
+ ```
+ def dataset_fn():
+ return tf.data.Dataset.from_tensors([[1.]]).repeat()
+
+ with strategy.scope():
+ distributed_dataset = strategy.distribute_dataset(dataset_fn)
+ iterator = distributed_dataset.make_initializable_iterator()
+ replica_results = strategy.extended.call_for_each_replica(
+ replica_fn, args=(iterator.get_next(),))
+ ```
+
+ Args:
+ dataset_fn: A function that returns a `tf.data.Dataset`.
+
+ Returns:
+ A `PerReplicaDataset` that will produce data for each replica.
+ """
+ return self._extended._distribute_dataset(dataset_fn) # pylint: disable=protected-access
+
+ def make_dataset_iterator(self, dataset):
+ """Makes an iterator for input provided via input_dataset.
+
+ Data from the given dataset will be distributed evenly across all the
+ compute replicas. We will assume that the input dataset is batched by the
+ global batch size. With this assumption, we will make a best effort to
+ divide each batch across all the replicas (one or more workers).
+ If this effort fails, an error will be thrown, and the user should instead
+ use `make_input_fn_iterator` which provides more control to the user, and
+ does not try to divide a batch across replicas.
+
+ The user could also use `make_input_fn_iterator` if they want to
+ customize which input is fed to which replica/worker etc.
+
+ Args:
+ dataset: `tf.data.Dataset` that will be distributed evenly across all
+ replicas.
+
+ Returns:
+ An `tf.distribute.InputIterator` which returns inputs for each step of the
+ computation. User should call `initialize` on the returned iterator.
+ """
+ return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access
+
+ def make_input_fn_iterator(self,
+ input_fn,
+ replication_mode=InputReplicationMode.PER_WORKER):
+ """Returns an iterator split across replicas created from an input function.
+
+ The `input_fn` should take an `tf.distribute.InputContext` object where
+ information about input sharding can be accessed:
+
+ ```
+ def input_fn(input_context):
+ d = tf.data.Dataset.from_tensors([[1.]]).repeat()
+ return d.shard(input_context.num_input_pipelines,
+ input_context.input_pipeline_id)
+ with strategy.scope():
+ iterator = strategy.make_input_fn_iterator(
+ input_fn)
+ replica_results = strategy.extended.call_for_each_replica(
+ replica_fn, iterator.get_next())
+ ```
+
+ Args:
+ input_fn: A function that returns a `tf.data.Dataset`. This function is
+ expected to take an `tf.distribute.InputContext` object.
+ replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
+ Only `PER_WORKER` is supported currently.
+
+ Returns:
+ An iterator object that can be initialized and fetched next element.
+ """
+ if replication_mode != InputReplicationMode.PER_WORKER:
+ raise ValueError(
+ "Input replication mode not supported: %r" % replication_mode)
+ return self.extended._make_input_fn_iterator( # pylint: disable=protected-access
+ input_fn, replication_mode=replication_mode)
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def broadcast(self, tensor, destinations=None):
+ """DEPRECATED: use extended.broadcast_to() instead."""
+ return self._extended.broadcast_to(tensor, destinations)
+
+ @doc_controls.do_not_generate_docs # Use experimental_initialize() instead.
+ def initialize(self):
+ """DEPRECATED: Use `experimental_initialize()` instead."""
+ return self._extended._initialize() # pylint: disable=protected-access
+
+ def experimental_initialize(self):
+ """Any initialization to be done before running any computations.
+
+ In eager mode, it executes any initialization as a side effect.
+ In graph mode, it creates the initialization ops and returns them.
+
+ For example, TPU initialize_system ops.
+
+ Returns:
+ A list of ops to execute.
+ """
+ return self._extended._initialize() # pylint: disable=protected-access
+
+ @doc_controls.do_not_generate_docs # Use experimental_finalize() instead.
+ def finalize(self):
+ """DEPRECATED: Use `experimental_finalize()` instead."""
+ return self._extended._finalize() # pylint: disable=protected-access
+
+ def experimental_finalize(self):
+ """Any final actions to be done at the end of all computations.
+
+ In eager mode, it executes any finalize actions as a side effect.
+ In graph mode, it creates the finalize ops and returns them.
+
+ For example, TPU shutdown ops.
+
+ Returns:
+ A list of ops to execute.
+ """
+ return self._extended._finalize() # pylint: disable=protected-access
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def run_steps_on_dataset(self, fn, iterator, iterations=1,
+ initial_loop_values=None):
+ """DEPRECATED: use extended.experimental_run_steps_on_iterator() instead."""
+ return self._extended.experimental_run_steps_on_iterator(
+ fn, iterator, iterations, initial_loop_values)
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def call_for_each_replica(self, fn, *args, **kwargs):
+ """DEPRECATED: use extended.call_for_each_replica() instead."""
+ # Handle old *args, **kwargs, and new args=(...), kwargs={...}, to
+ # allow transition.
+ a = kwargs.pop("args", None)
+ if a is not None:
+ if args:
+ raise ValueError(
+ "Can't pass *args and args=... to call_for_each_replica")
+ args = a
+ k = kwargs.pop("kwargs", None)
+ if k is not None:
+ if kwargs:
+ raise ValueError(
+ "Can't pass **kwargs and kwargs=... to call_for_each_replica")
+ kwargs = k
+ kwargs.pop("run_concurrently", None) # Ignore old option.
+ return self._extended.call_for_each_replica(fn, args, kwargs)
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def reduce(self, aggregation, value, destinations):
+ """DEPRECATED: use extended.reduce_to() instead."""
+ return self._extended.reduce_to(aggregation, value, destinations)
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def batch_reduce(self, aggregation, value_destination_pairs):
+ """DEPRECATED: use extended.batch_reduce_to() instead."""
+ return self._extended.batch_reduce_to(aggregation, value_destination_pairs)
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def update(self, var, fn, *args, **kwargs):
+ """DEPRECATED: use extended.update() instead."""
+ group = kwargs.pop("group", True)
+ # We temporarily support "grouped" in addition to "group" for backward-
+ # compatibility.
+ group = kwargs.pop("grouped", True) and group
+ # Handle old *args, **kwargs, and new args=(...), kwargs={...}, to
+ # allow transition.
+ a = kwargs.pop("args", None)
+ if a is not None:
+ if args:
+ raise ValueError(
+ "Can't pass *args and args=... to update")
+ args = a
+ k = kwargs.pop("kwargs", None)
+ if k is not None:
+ if kwargs:
+ raise ValueError(
+ "Can't pass **kwargs and kwargs=... to update")
+ kwargs = k
+ return self._extended.update(var, fn, args, kwargs, group)
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ """DEPRECATED: use extended.update_non_slot() instead."""
+ group = kwargs.pop("group", True)
+ # We temporarily support "grouped" in addition to "group" for backward-
+ # compatibility.
+ group = kwargs.pop("grouped", True) and group
+ # Handle old *args, **kwargs, and new args=(...), kwargs={...}, to
+ # allow transition.
+ a = kwargs.pop("args", None)
+ if a is not None:
+ if args:
+ raise ValueError(
+ "Can't pass *args and args=... to update_non_slot")
+ args = a
+ k = kwargs.pop("kwargs", None)
+ if k is not None:
+ if kwargs:
+ raise ValueError(
+ "Can't pass **kwargs and kwargs=... to update_non_slot")
+ kwargs = k
+ return self._extended.update_non_slot(
+ colocate_with, fn, args, kwargs, group)
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, -> `DistributedValues`
+ def unwrap(self, value):
+ """Returns the list of all per-replica values contained in `value`.
+
+ Args:
+ value: A value returned by `extended.call_for_each_replica()` or a
+ variable created in `scope`.
+
+ Returns:
+ A list of values contained in `value`. If `value` represents a single
+ value, this returns `[value].`
+ """
+ return self._extended._unwrap(value) # pylint: disable=protected-access
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def value_container(self, value):
+ """DEPRECATED: use extended.value_container() instead."""
+ return self._extended.value_container(value)
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, -> `DistributedValues`
+ def group(self, value, name=None):
+ """Shortcut for `tf.group(self.unwrap(value))`."""
+ return self._extended._group(value, name) # pylint: disable=protected-access
+
+ @property
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def require_static_shapes(self):
+ """DEPRECATED: use extended.require_static_shapes instead."""
+ return self._extended.experimental_require_static_shapes
+
+ @property
+ def num_replicas_in_sync(self):
+ """Returns number of replicas over which gradients are aggregated."""
+ return self._extended._num_replicas_in_sync # pylint: disable=protected-access
+
+ @property
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def worker_devices(self):
+ """DEPRECATED: use extended.worker_devices instead."""
+ return self._extended.worker_devices
+
+ @property
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def parameter_devices(self):
+ """DEPRECATED: use extended.parameter_devices instead."""
+ return self._extended.parameter_devices
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def non_slot_devices(self, var_list):
+ """DEPRECATED: use extended.non_slot_devices instead."""
+ return self._extended.non_slot_devices(var_list)
+
+ @property
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def between_graph(self):
+ """DEPRECATED: use extended.experimental_between_graph instead."""
+ return self._extended.experimental_between_graph
+
+ @doc_controls.do_not_generate_docs # DEPRECATED, being replaced by a new API.
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ # pylint: disable=g-doc-return-or-yield,g-doc-args
+ """DEPRECATED: use `update_config_proto` instead.
+
+ Configures the strategy class.
+
+ DEPRECATED: This method's functionality has been split into the strategy
+ constructor and `update_config_proto`. In the future, we will allow passing
+ cluster and config_proto to the constructor to configure the strategy. And
+ `update_config_proto` can be used to update the config_proto based on the
+ specific strategy.
+ """
+ return self._extended._configure( # pylint: disable=protected-access
+ session_config, cluster_spec, task_type, task_id)
+
+ def update_config_proto(self, config_proto):
+ """Returns a copy of `config_proto` modified for use with this strategy.
+
+ The updated config has something needed to run a strategy, e.g.
+ configuration to run collective ops, or device filters to improve
+ distributed training performance.
+
+ Args:
+ config_proto: a `tf.ConfigProto` object.
+
+ Returns:
+ The updated copy of the `config_proto`.
+ """
+ return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access
+
+ @property
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def should_init(self):
+ """DEPRECATED: use extended.should_init instead."""
+ return self._extended.experimental_should_init
+
+ @property
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def should_checkpoint(self):
+ """DEPRECATED: use extended.should_checkpoint instead."""
+ return self._extended.should_checkpoint
+
+ @property
+ @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
+ def should_save_summary(self):
+ """DEPRECATED: use extended.should_save_summary instead."""
+ return self._extended.should_save_summary
+
+ def __deepcopy__(self, memo):
+ # First do a regular deepcopy of `self`.
+ cls = self.__class__
+ result = cls.__new__(cls)
+ memo[id(self)] = result
+ for k, v in self.__dict__.items():
+ setattr(result, k, copy.deepcopy(v, memo))
+ # One little fix-up: we want `result._extended` to reference `result`
+ # instead of `self`.
+ result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access
+ return result
+
+ def __copy__(self):
+ raise RuntimeError("Must only deepcopy DistributionStrategy.")
+
+
+@tf_export("distribute.StrategyExtended")
+class DistributionStrategyExtended(object):
+ """Additional APIs for algorithms that need to be distribution-aware.
+
+ The intent is that you can write an algorithm in a stylized way and
+ it will be usable with a variety of different
+ `tf.distribute.Strategy`
+ implementations. Each descendant will implement a different strategy
+ for distributing the algorithm across multiple devices/machines.
+ Furthermore, these changes can be hidden inside the specific layers
+ and other library classes that need special treatment to run in a
+ distributed setting, so that most users' model definition code can
+ run unchanged. The `tf.distribute.Strategy` API works the same way
+ with eager and graph execution.
+
+ First let's introduce a few high-level concepts:
+
+ * _Data parallelism_ is where we run multiple copies of the model
+ on different slices of the input data. This is in contrast to
+ _model parallelism_ where we divide up a single copy of a model
+ across multiple devices.
+ Note: we only support data parallelism for now, but
+ hope to add support for model parallelism in the future.
+ * A _replica_ is one copy of the model, running on one slice of the
+ input data.
+ * _Synchronous_, or more commonly _sync_, training is where the
+ updates from each replica are aggregated together before updating
+ the model variables. This is in contrast to _asynchronous_, or
+ _async_ training, where each replica updates the model variables
+ independently.
+ * Furthermore you might run your computation on multiple devices
+ on one machine (or "host"), or on multiple machines/hosts.
+ If you are running on multiple machines, you might have a
+ single master host that drives computation across all of them,
+ or you might have multiple clients driving the computation
+ asynchronously.
+
+ To distribute an algorithm, we might use some of these ingredients:
+
+ * Parameter servers: These are hosts that hold a single copy of
+ parameters/variables. All replicas that want to operate on a variable
+ retrieve it at the beginning of a step and send an update to be
+ applied at the end of the step. Can support either sync or async
+ training.
+ * Mirrored variables: These are variables that are copied to multiple
+ devices, where we keep the copies in sync by applying the same
+ updates to every copy. Normally would only be used with sync training.
+ * Reductions and Allreduce: A _reduction_ is some method of
+ aggregating multiple values into one value, like "sum" or
+ "mean". If doing sync training, we will perform a reduction on the
+ gradients to a parameter from all replicas before applying the
+ update. Allreduce is an algorithm for performing a reduction on
+ values from multiple devices and making the result available on
+ all of those devices.
+ * In the future we will have support for TensorFlow's partitioned
+ variables, where a single variable is split across multiple
+ devices.
+
+ We have then a few approaches we want to support:
+
+ * Code written (as if) with no knowledge of class `tf.distribute.Strategy`.
+ This code should work as before, even if some of the layers, etc.
+ used by that code are written to be distribution-aware. This is done
+ by having a default `tf.distribute.Strategy` that gives ordinary behavior,
+ and by default being in a single replica context.
+ * Ordinary model code that you want to run using a specific
+ `tf.distribute.Strategy`. This can be as simple as:
+
+ ```
+ with my_strategy.scope():
+ iterator = my_strategy.make_dataset_iterator(dataset)
+ session.run(iterator.initialize())
+ replica_train_ops = my_strategy.extended.call_for_each_replica(
+ replica_fn, args=(iterator.get_next(),))
+ train_op = my_strategy.group(replica_train_ops)
+ ```
+
+ This takes an ordinary `dataset` and `replica_fn` and runs it
+ distributed using a particular `tf.distribute.Strategy` in
+ `my_strategy`. Any variables created in `replica_fn` are created
+ using `my_strategy`'s policy, and library functions called by
+ `replica_fn` can use the `get_replica_context()` API to get enhanced
+ behavior in this case.
+
+ * If you want to write a distributed algorithm, you may use any of
+ the `tf.distribute.Strategy` APIs inside a
+ `with my_strategy.scope():` block of code.
+
+ Lower-level concepts:
+
+ * Wrapped values: In order to represent values parallel across devices
+ (either replicas or the devices associated with a particular value), we
+ wrap them in a "PerReplica" or "Mirrored" object that contains a map
+ from device to values. "PerReplica" is used when the value may be
+ different across replicas, and "Mirrored" when the value are the same.
+ * Unwrapping and merging: Consider calling a function `fn` on multiple
+ replicas, like `extended.call_for_each_replica(fn, args=[w])` with an
+ argument `w` that is a wrapped value. This means `w` will have a map taking
+ replica device `d0` to `w0`, replica device `d1` to `w1`,
+ etc. `extended.call_for_each_replica()` unwraps `w` before calling `fn`, so
+ it calls `fn(w0)` on `d0`, `fn(w1)` on `d1`, etc. It then merges the return
+ values from `fn()`, which can possibly result in wrapped values. For
+ example, let's say `fn()` returns a tuple with three components: `(x, a,
+ v0)` from replica 0, `(x, b, v1)` on replica 1, etc. If the first component
+ is the same object `x` from every replica, then the first component of the
+ merged result will also be `x`. If the second component is different (`a`,
+ `b`, ...) from each replica, then the merged value will have a wrapped map
+ from replica device to the different values. If the third component is the
+ members of a mirrored variable (`v` maps `d0` to `v0`, `d1` to `v1`, etc.),
+ then the merged result will be that mirrored variable (`v`).
+ * Replica context vs. Cross-replica context: _replica context_ is when we
+ are in some function that is being called once for each replica.
+ Otherwise we are in cross-replica context, which is useful for
+ calling `tf.distribute.Strategy` methods which operate across the
+ replicas (like `reduce_to()`). By default you start in a replica context
+ (the default "single replica context") and then some methods can
+ switch you back and forth, as described below.
+ * Worker devices vs. parameter devices: Most replica computations will
+ happen on worker devices. Since we don't yet support model
+ parallelism, there will be one worker device per replica. When using
+ parameter servers (see above), the set of devices holding
+ variables may be different, otherwise the parameter devices might
+ match the worker devices.
+ * Non-slot devices are some subset of the parameter devices where we
+ put all the non-slot variables. We need to ensure that all
+ non-slot variables are allocated on the same device, or mirrored
+ across the same set of devices. If you have some variable you want
+ to colocate all the non-slot variables with, you can use
+ `colocate_vars_with()` to get the remaining non-slot variables on
+ the same device. Otherwise you can use `non_slot_devices()` to
+ pick a consistent set of devices to pass to both
+ `colocate_vars_with()` and `update_non_slot()`.
+
+ When using a `tf.distribute.Strategy`, we have a new type dimension
+ called _locality_ that says what values are compatible with which
+ APIs:
+
+ * T: different value for each replica (e.g. a PerReplica-wrapped value).
+ * M: value is "mirrored" across replicas, i.e. there are copies with the
+ same value on each replica (e.g. a Mirrored-wrapped value).
+ * V(`v`): value is "mirrored" across all the devices which have a
+ copy of variable `v` (also a Mirrored-wrapped value, but over
+ parameter devices instead of worker devices).
+ * N: value is "mirrored" across all the "non-slot" devices
+
+ Rules for methods with respect to locality and single-replica vs.
+ cross-replica context:
+
+ * `with d.scope()`: default single-replica context -> cross-replica context
+ for `d`
+ * `with d.extended.colocate_vars_with(v)`: in replica/cross-replica context,
+ variables will be created with locality V(`v`). That is, if we write
+ `with d.extended.colocate_vars_with(v1): v2 = tf.get_variable(...)`,
+ then `v2` will have locality V(`v1`), i.e. locality V(`v2`) will equal
+ V(`v1`).
+ * `with d.extended.colocate_vars_with(d.extended.non_slot_devices(...))`: in
+ replica/cross-replica context, variables will be created with locality N
+ * `v = tf.get_variable(...)`: in replica/cross-replica context, creates
+ a variable (which by definition will have locality V(`v`), though
+ will match another locality if inside a `colocate_vars_with`
+ scope).
+ * `d.make_dataset_iterator(dataset)` (or the deprecated
+ `d.distribute_dataset(dataset).make_one_shot_iterator()`): in cross-replica
+ context, produces an iterator with locality T
+ * `d.extended.broadcast_to(t)`: in cross-replica context, produces a value
+ with locality M
+ * `d.extended.broadcast_to(t, v)`: in cross-replica context, produces a value
+ with locality V(`v`)
+ * `d.extended.call_for_each_replica(fn, ...)`: in cross-replica context, runs
+ `fn()` in a replica context (and so may call `get_replica_context()` and
+ use its API, including `merge_call()` to get back to cross-replica
+ context), once for each replica. May use values with locality T or
+ M, and any variable.
+ * `d.extended.reduce_to(m, t, t)`: in cross-replica context, accepts t with
+ locality T and produces a value with locality M.
+ * `d.extended.reduce_to(m, t, v)`: in cross-replica context, accepts t with
+ locality T and produces a value with locality V(`v`).
+ * `d.extended.batch_reduce_to(m, [(t, v)]): see `d.extended.reduce_to()`
+ * `d.extended.update(v, fn, ...)`: in cross-replica context, runs `fn()` once
+ for each device `v` is copied to, all inputs should have locality
+ V(`v`), output will have locality V(`v`) as well.
+ * `d.extended.update_non_slot(d.extended.non_slot_devices(), fn)`: in
+ cross-replica context, like `d.extended.update()` except with locality N.
+ * `d.extended.read_var(v)`: Gets the (read-only) value of the variable `v` (on
+ the device determined by the current device scope), aggregating
+ across replicas for replica-local variables. Frequently, this will be
+ done automatically when using `v` in an expression or fetching it in
+ a cross-replica context, but this function can be used to force that
+ conversion happens at a particular point in time (for example, to
+ add the result of the conversion to a graph collection).
+
+ The standard pattern for updating variables is to:
+
+ 1. Create an input iterator with `d.make_dataset_iterator()`.
+ 2. Define each replica `d.extended.call_for_each_replica()` up to the point of
+ getting a list of gradient, variable pairs.
+ 3. Call `d.extended.reduce_to(VariableAggregation.SUM, t, v)` or
+ `d.extended.batch_reduce_to()` to sum the gradients (with locality T)
+ into values with locality V(`v`).
+ 4. Call `d.extended.update(v)` for each variable to update its value.
+
+ Steps 3 and 4 are done automatically by class `Optimizer` if you call
+ its `apply_gradients` method in a replica context. Otherwise you can
+ manually call its `_distributed_apply` method in a cross-replica context.
+
+ Another thing you might want to do in the middle of your replica function is
+ an all-reduce of some intermediate value, using `d.extended.reduce_to()` or
+ `d.extended.batch_reduce_to()`. You simply provide the same tensor as the
+ input and destination.
+
+ Layers should expect to be called in a replica context, and can use
+ the `tf.distribute.get_replica_context` function to get a
+ `tf.distribute.ReplicaContext` object. The
+ `ReplicaContext` object has a `merge_call()` method for entering
+ cross-replica context where you can use `reduce_to()` (or
+ `batch_reduce_to()`) and then optionally `update()` to update state.
+
+ You may use this API whether or not a `tf.distribute.Strategy` is
+ being used, since there is a default implementation of
+ `ReplicaContext` and `tf.distribute.Strategy`.
+
+ NOTE for new `tf.distribute.Strategy` implementations: Please put all logic
+ in a subclass of `tf.distribute.StrategyExtended`. The only code needed for
+ the `tf.distribute.Strategy` subclass is for instantiating your subclass of
+ `tf.distribute.StrategyExtended` in the `__init__` method.
+ """
+
+ def __init__(self, container_strategy):
+ self._container_strategy_weakref = weakref.ref(container_strategy)
+ self._default_device = None
+ # This property is used to determine if we should set drop_remainder=True
+ # when creating Datasets from numpy array inputs.
+ self._require_static_shapes = False
+
+ def _container_strategy(self):
+ """Get the containing `DistributionStrategy`.
+
+ This should not generally be needed except when creating a new
+ `ReplicaContext` and to validate that the caller is in the correct
+ `scope()`.
+
+ Returns:
+ The `DistributionStrategy` such that `strategy.extended` is `self`.
+ """
+ container_strategy = self._container_strategy_weakref()
+ assert container_strategy is not None
+ return container_strategy
+
+ def _scope(self, strategy):
+ """Implementation of DistributionStrategy.scope()."""
+ if distribution_strategy_context.has_distribution_strategy():
+ _require_cross_replica_context_extended(self)
+ return _SameScopeAgainContext(strategy)
+
+ def creator_with_resource_vars(*args, **kwargs):
+ _require_distribution_strategy_scope_extended(self)
+ kwargs["use_resource"] = True
+ return self._create_variable(*args, **kwargs)
+
+ def distributed_getter(getter, *args, **kwargs):
+ if not self._allow_variable_partition():
+ if kwargs.pop("partitioner", None) is not None:
+ tf_logging.log_first_n(
+ tf_logging.WARN, "Partitioned variables are disabled when using "
+ "current tf.distribute.Strategy.", 1)
+ return getter(*args, **kwargs)
+
+ return _CurrentDistributionContext(
+ strategy,
+ variable_scope.variable_creator_scope(creator_with_resource_vars),
+ variable_scope.variable_scope(
+ variable_scope.get_variable_scope(),
+ custom_getter=distributed_getter), self._default_device)
+
+ def _allow_variable_partition(self):
+ return False
+
+ def _create_variable(self, next_creator, *args, **kwargs):
+ # Note: should support "colocate_with" argument.
+ raise NotImplementedError("must be implemented in descendants")
+
+ def read_var(self, v):
+ """Reads the value of a variable.
+
+ Returns the aggregate value of a replica-local variable, or the
+ (read-only) value of any other variable.
+
+ Args:
+ v: A variable allocated within the scope of this `tf.distribute.Strategy`.
+
+ Returns:
+ A tensor representing the value of `v`, aggregated across replicas if
+ necessary.
+ """
+ raise NotImplementedError("must be implemented in descendants")
+
+ def colocate_vars_with(self, colocate_with_variable):
+ """Scope that controls which devices variables will be created on.
+
+ No operations should be added to the graph inside this scope, it
+ should only be used when creating variables (some implementations
+ work by changing variable creation, others work by using a
+ tf.colocate_with() scope).
+
+ This may only be used inside `self.scope()`.
+
+ Example usage:
+
+ ```
+ with strategy.scope():
+ var1 = tf.get_variable(...)
+ with strategy.extended.colocate_vars_with(v1):
+ # var2 and var3 will be created on the same device(s) as var1
+ var2 = tf.get_variable(...)
+ var3 = tf.get_variable(...)
+
+ def fn(v1, v2, v3):
+ # operates on v1 from var1, v2 from var2, and v3 from var3
+
+ # `fn` runs on every device `v1` is on, `v2` and `v3` will be there too.
+ strategy.extended.update(v1, fn, args=(v2, v3))
+ ```
+
+ Args:
+ colocate_with_variable: A created in `self.scope()`. Variables created
+ while in the returned context manager will be on the same set of
+ devices as `colocate_with_variable`.
+
+ Returns:
+ A context manager.
+ """
+ def create_colocated_variable(next_creator, *args, **kwargs):
+ _require_distribution_strategy_scope_extended(self)
+ kwargs["use_resource"] = True
+ kwargs["colocate_with"] = colocate_with_variable
+ return next_creator(*args, **kwargs)
+
+ _require_distribution_strategy_scope_extended(self)
+ return variable_scope.variable_creator_scope(create_colocated_variable)
+
+ def _call_dataset_fn(self, dataset_fn):
+ """Call the `dataset_fn` with `input_context` as argument."""
+ result = dataset_fn()
+ if not isinstance(result, dataset_ops.DatasetV2):
+ raise ValueError(
+ "dataset_fn() must return a tf.data.Dataset when using a "
+ "tf.distribute.Strategy.")
+ return result
+
+ # TODO(josh11b): `PerReplicaDataset` currently only implements a few methods of
+ # Dataset API such as make_one_shot_iterator and make_initializable_iterator.
+ # Extend to implement more functionality of datasets.
+ def _distribute_dataset(self, dataset_fn):
+ raise NotImplementedError("must be implemented in descendants")
+
+ def _make_dataset_iterator(self, dataset):
+ raise NotImplementedError("must be implemented in descendants")
+
+ def _make_input_fn_iterator(self, input_fn, replication_mode):
+ raise NotImplementedError("must be implemented in descendants")
+
+ def broadcast_to(self, tensor, destinations):
+ """Mirror a tensor on one device to all worker devices.
+
+ Args:
+ tensor: A Tensor value to broadcast.
+ destinations: A mirrored variable, device string, or list of device
+ strings, specifying the destination devices to copy `tensor` to.
+
+ Returns:
+ A value mirrored to `destinations` devices.
+ """
+ # TODO(josh11b): More docstring
+ _require_cross_replica_context_extended(self)
+ return self._broadcast_to(tensor, destinations)
+
+ def _broadcast_to(self, tensor, destinations):
+ raise NotImplementedError("must be implemented in descendants")
+
+ def _initialize(self):
+ return []
+
+ def _finalize(self):
+ return []
+
+ def experimental_run_steps_on_iterator(self, fn, iterator, iterations=1,
+ initial_loop_values=None):
+ """Run `fn` with input from `iterator` for `iterations` times.
+
+ This method can be used to run a step function for training a number of
+ times using input from a dataset.
+
+ Args:
+ fn: function to run using this distribution strategy. The function must
+ have the following signature: `def fn(context, inputs)`.
+ `context` is an instance of `MultiStepContext` that will be passed when
+ `fn` is run. `context` can be used to specify the outputs to be returned
+ from `fn` by calling `context.set_last_step_output`. It can also be used
+ to capture non tensor outputs by `context.set_non_tensor_output`.
+ See `MultiStepContext` documentation for more information.
+ `inputs` will have same type/structure as `iterator.get_next()`.
+ Typically, `fn` will use `call_for_each_replica` method of the strategy
+ to distribute the computation over multiple replicas.
+ iterator: Iterator of a dataset that represents the input for `fn`. The
+ caller is responsible for initializing the iterator as needed.
+ iterations: (Optional) Number of iterations that `fn` should be run.
+ Defaults to 1.
+ initial_loop_values: (Optional) Initial values to be passed into the
+ loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove
+ initial_loop_values argument when we have a mechanism to infer the
+ outputs of `fn`.
+
+ Returns:
+ Returns the `MultiStepContext` object which has the following properties,
+ among other things:
+ - run_op: An op that runs `fn` `iterations` times.
+ - last_step_outputs: A dictionary containing tensors set using
+ `context.set_last_step_output`. Evaluating this returns the value of
+ the tensors after the last iteration.
+ - non_tensor_outputs: A dictionatry containing anything that was set by
+ `fn` by calling `context.set_non_tensor_output`.
+ """
+ _require_cross_replica_context_extended(self)
+ return self._experimental_run_steps_on_iterator(
+ fn, iterator, iterations, initial_loop_values)
+
+ def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
+ initial_loop_values):
+ raise NotImplementedError("must be implemented in descendants")
+
+ def call_for_each_replica(self, fn, args=(), kwargs=None):
+ """Run `fn` once per replica.
+
+ `fn` may call `tf.get_replica_context()` to access methods such as
+ `replica_id_in_sync_group` and `merge_call()`.
+
+ `merge_call()` is used to communicate between the replicas and
+ re-enter the cross-replica context. All replicas pause their execution
+ having encountered a `merge_call()` call. After that the
+ `merge_fn`-function is executed. Its results are then unwrapped and
+ given back to each replica call. After that execution resumes until
+ `fn` is complete or encounters another `merge_call()`. Example:
+
+ ```python
+ # Called once in "cross-replica" context.
+ def merge_fn(distribution, three_plus_replica_id):
+ # sum the values across replicas
+ return sum(distribution.unwrap(three_plus_replica_id))
+
+ # Called once per replica in `distribution`, in a "replica" context.
+ def fn(three):
+ replica_ctx = tf.get_replica_context()
+ v = three + replica_ctx.replica_id_in_sync_group
+ # Computes the sum of the `v` values across all replicas.
+ s = replica_ctx.merge_call(merge_fn, args=(v,))
+ return s + v
+
+ with distribution.scope():
+ # in "cross-replica" context
+ ...
+ merged_results = distribution.call_for_each_replica(fn, args=[3])
+ # merged_results has the values from every replica execution of `fn`.
+ print(distribution.unwrap(merged_results)) # Prints a list
+ ```
+
+ Args:
+ fn: function to run (will be run once per replica).
+ args: Tuple or list with positional arguments for `fn`.
+ kwargs: Dict with keyword arguments for `fn`.
+
+ Returns:
+ Merged return value of `fn` across all replicas.
+ """
+ _require_cross_replica_context_extended(self)
+ if kwargs is None:
+ kwargs = {}
+ return self._call_for_each_replica(fn, args, kwargs)
+
+ def _call_for_each_replica(self, fn, args, kwargs):
+ raise NotImplementedError("must be implemented in descendants")
+
+ def reduce_to(self, reduce_op, value, destinations):
+ """Combine (via e.g. sum or mean) values across replicas.
+
+ Args:
+ reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
+ DEPRECATED but still accepted values:
+ `tf.VariableAggregation.SUM`,
+ `tf.VariableAggregation.MEAN`,
+ value: A per-replica value with one value per replica.
+ destinations: A mirrored variable, a per-replica tensor, a device string,
+ or list of device strings. The return value will be copied to all
+ destination devices (or all the devices where the `destinations` value
+ resides). To perform an all-reduction, pass `value` to `destinations`.
+
+ Returns:
+ A value mirrored to `destinations`.
+ """
+ # TODO(josh11b): More docstring
+ # TODO(josh11b): Return an unwrapped value if colocate_with is a
+ # single device.
+ _require_cross_replica_context_extended(self)
+
+ # TODO(priyag): Remove this when all callers have been updated.
+ if isinstance(reduce_op, variable_scope.VariableAggregation):
+ assert reduce_op in [
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN,
+ ]
+ reduce_op = reduce_util.ReduceOp.from_variable_aggregation(reduce_op)
+ return self._reduce_to(reduce_op, value, destinations)
+
+ def _reduce_to(self, reduce_op, value, destinations):
+ raise NotImplementedError("must be implemented in descendants")
+
+ def batch_reduce_to(self, reduce_op, value_destination_pairs):
+ """Combine multiple `reduce_to` calls into one for faster execution.
+
+ Args:
+ reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
+ DEPRECATED but still accepted values:
+ `tf.VariableAggregation.SUM`,
+ `tf.VariableAggregation.MEAN`,
+ value_destination_pairs: A sequence of (value, destinations)
+ pairs. See `reduce_to()` for a description.
+
+ Returns:
+ A list of mirrored values, one per pair in `value_destination_pairs`.
+ """
+ # TODO(josh11b): More docstring
+ _require_cross_replica_context_extended(self)
+
+ # TODO(priyag): Remove this when all callers have been updated.
+ if isinstance(reduce_op, variable_scope.VariableAggregation):
+ assert reduce_op in [
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN,
+ ]
+ reduce_op = reduce_util.ReduceOp.from_variable_aggregation(reduce_op)
+ return self._batch_reduce_to(reduce_op, value_destination_pairs)
+
+ def _batch_reduce_to(self, reduce_op, value_destination_pairs):
+ return [
+ self.reduce_to(reduce_op, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
+
+ def update(self, var, fn, args=(), kwargs=None, group=True):
+ """Run `fn` to update `var` using inputs mirrored to the same devices.
+
+ If `var` is mirrored across multiple devices, then this implements
+ logic like:
+
+ ```
+ results = {}
+ for device, v in var:
+ with tf.device(device):
+ # args and kwargs will be unwrapped if they are mirrored.
+ results[device] = fn(v, *args, **kwargs)
+ return merged(results)
+ ```
+
+ Otherwise this returns `fn(var, *args, **kwargs)` colocated with `var`.
+
+ Neither `args` nor `kwargs` may contain per-replica values.
+ If they contain mirrored values, they will be unwrapped before
+ calling `fn`.
+
+ Args:
+ var: Variable, possibly mirrored to multiple devices, to operate on.
+ fn: Function to call. Should take the variable as the first argument.
+ args: Tuple or list. Additional positional arguments to pass to `fn()`.
+ kwargs: Dict with keyword arguments to pass to `fn()`.
+ group: Boolean. Defaults to True. If False, the return value will be
+ unwrapped.
+
+ Returns:
+ By default, the merged return value of `fn` across all replicas. The
+ merged result has dependencies to make sure that if it is evaluated at
+ all, the side effects (updates) will happen on every replica. If instead
+ "group=False" is specified, this function will return a nest of lists
+ where each list has an element per replica, and the caller is responsible
+ for ensuring all elements are executed.
+ """
+ _require_cross_replica_context_extended(self)
+ if kwargs is None:
+ kwargs = {}
+ return self._update(var, fn, args, kwargs, group)
+
+ def _update(self, var, fn, args, kwargs, group):
+ raise NotImplementedError("must be implemented in descendants")
+
+ def update_non_slot(
+ self, colocate_with, fn, args=(), kwargs=None, group=True):
+ """Runs `fn(*args, **kwargs)` on `colocate_with` devices.
+
+ Args:
+ colocate_with: The return value of `non_slot_devices()`.
+ fn: Function to execute.
+ args: Tuple or list. Positional arguments to pass to `fn()`.
+ kwargs: Dict with keyword arguments to pass to `fn()`.
+ group: Boolean. Defaults to True. If False, the return value will be
+ unwrapped.
+
+ Returns:
+ Return value of `fn`, possibly merged across devices.
+ """
+ _require_cross_replica_context_extended(self)
+ if kwargs is None:
+ kwargs = {}
+ return self._update_non_slot(colocate_with, fn, args, kwargs, group)
+
+ def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
+ raise NotImplementedError("must be implemented in descendants")
+
+ def _unwrap(self, distributed_value):
+ raise NotImplementedError("must be implemented in descendants")
+
+ def value_container(self, value):
+ """Returns the container that this per-replica `value` belongs to.
+
+ Args:
+ value: A value returned by `call_for_each_replica()` or a variable
+ created in `scope()`.
+
+ Returns:
+ A container that `value` belongs to.
+ If value does not belong to any container (including the case of
+ container having been destroyed), returns the value itself.
+ `value in unwrap(value_container(value))` will always be true.
+ """
+ raise NotImplementedError("must be implemented in descendants")
+
+ def _group(self, value, name=None):
+ """Shortcut for `tf.group(distribution.unwrap(value))`."""
+ value = nest.flatten(self._unwrap(value))
+
+ if len(value) != 1 or name is not None:
+ return control_flow_ops.group(value, name=name)
+ # Special handling for the common case of one op.
+ v, = value
+ if hasattr(v, "op"):
+ v = v.op
+ return v
+
+ @property
+ def experimental_require_static_shapes(self):
+ return self._require_static_shapes
+
+ @property
+ def _num_replicas_in_sync(self):
+ """Returns number of replicas over which gradients are aggregated."""
+ raise NotImplementedError("must be implemented in descendants")
+
+ @property
+ def worker_devices(self):
+ """Returns the list of devices used to run `call_for_each_replica()` calls.
+ """
+ # TODO(josh11b): More docstring
+ raise NotImplementedError("must be implemented in descendants")
+
+ @property
+ def parameter_devices(self):
+ """Returns the list of devices used for variable and `update` placement."""
+ # TODO(josh11b): More docstring
+ raise NotImplementedError("must be implemented in descendants")
+
+ def non_slot_devices(self, var_list):
+ """Device(s) for non-slot variables.
+
+ Create variables on these devices in a
+ `with colocate_vars_with(non_slot_devices(...)):` block.
+ Update those using `update_non_slot()`.
+
+ Args:
+ var_list: The list of variables being optimized, needed with the
+ default `tf.distribute.Strategy`.
+ """
+ raise NotImplementedError("must be implemented in descendants")
+
+ @property
+ def experimental_between_graph(self):
+ """Whether the strategy uses between-graph replication or not.
+
+ This is expected to return a constant value that will not be changed
+ throughout its life cycle.
+ """
+ raise NotImplementedError("must be implemented in descendants")
+
+ def _configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ """Configures the strategy class."""
+ del session_config, cluster_spec, task_type, task_id
+
+ def _update_config_proto(self, config_proto):
+ return copy.deepcopy(config_proto)
+
+ @property
+ def experimental_should_init(self):
+ """Whether initialization is needed."""
+ raise NotImplementedError("must be implemented in descendants")
+
+ @property
+ def should_checkpoint(self):
+ """Whether checkpointing is needed."""
+ raise NotImplementedError("must be implemented in descendants")
+
+ @property
+ def should_save_summary(self):
+ """Whether saving summaries is needed."""
+ raise NotImplementedError("must be implemented in descendants")
+
+
+# A note about the difference between the context managers
+# `ReplicaContext` (defined here) and `_CurrentDistributionContext`
+# (defined above) used by `DistributionStrategy.scope()`:
+#
+# * a ReplicaContext is only present during a `call_for_each_replica()`
+# call (except during a `merge_run` call) and in such a scope it
+# will be returned by calls to `get_replica_context()`. Implementers of new
+# DistributionStrategy descendants will frequently also need to
+# define a descendant of ReplicaContext, and are responsible for
+# entering and exiting this context.
+#
+# * DistributionStrategy.scope() sets up a variable_creator scope that
+# changes variable creation calls (e.g. to make mirrored
+# variables). This is intended as an outer scope that users enter once
+# around their model creation and graph definition. There is no
+# anticipated need to define descendants of _CurrentDistributionContext.
+# It sets the current DistributionStrategy for purposes of
+# `get_strategy()` and `has_strategy()`
+# and switches the thread mode to a "cross-replica context".
+@tf_export("distribute.ReplicaContext")
+class ReplicaContext(object):
+ """`tf.distribute.Strategy` API when in a replica context.
+
+ To be used inside your replicated step function, such as in a
+ `tf.distribute.StrategyExtended.call_for_each_replica` call.
+ """
+
+ def __init__(self, strategy, replica_id_in_sync_group):
+ self._distribution_strategy = strategy
+ self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access
+ self)
+ self._replica_id_in_sync_group = replica_id_in_sync_group
+
+ def __enter__(self):
+ _push_per_thread_mode(self._thread_context)
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ _pop_per_thread_mode()
+
+ def merge_call(self, merge_fn, args=(), kwargs=None):
+ """Merge args across replicas and run `merge_fn` in a cross-replica context.
+
+ This allows communication and coordination when there are multiple calls
+ to a model function triggered by a call to
+ `strategy.extended.call_for_each_replica(model_fn, ...)`.
+
+ See `tf.distribute.StrategyExtended.call_for_each_replica` for an
+ explanation.
+
+ If not inside a distributed scope, this is equivalent to:
+
+ ```
+ strategy = tf.distribute.get_strategy()
+ with cross-replica-context(strategy):
+ return merge_fn(strategy, *args, **kwargs)
+ ```
+
+ Args:
+ merge_fn: function that joins arguments from threads that are given as
+ PerReplica. It accepts `tf.distribute.Strategy` object as
+ the first argument.
+ args: List or tuple with positional per-thread arguments for `merge_fn`.
+ kwargs: Dict with keyword per-thread arguments for `merge_fn`.
+
+ Returns:
+ The return value of `merge_fn`, except for `PerReplica` values which are
+ unpacked.
+ """
+ require_replica_context(self)
+ if kwargs is None:
+ kwargs = {}
+ return self._merge_call(merge_fn, args, kwargs)
+
+ def _merge_call(self, merge_fn, args, kwargs):
+ """Default implementation for single replica."""
+ _push_per_thread_mode( # thread-local, so not needed with multiple threads
+ distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access
+ self._distribution_strategy))
+ try:
+ return merge_fn(self._distribution_strategy, *args, **kwargs)
+ finally:
+ _pop_per_thread_mode()
+
+ @property
+ def num_replicas_in_sync(self):
+ """Returns number of replicas over which gradients are aggregated."""
+ return self._distribution_strategy.num_replicas_in_sync
+
+ @property
+ def replica_id_in_sync_group(self):
+ """Which replica is being defined, from 0 to `num_replicas_in_sync - 1`."""
+ require_replica_context(self)
+ return self._replica_id_in_sync_group
+
+ @property
+ @doc_controls.do_not_generate_docs # DEPRECATED, use `strategy`
+ def distribution_strategy(self):
+ """DEPRECATED: use `self.stratgey` instead."""
+ return self._distribution_strategy
+
+ @property
+ def strategy(self):
+ """The current `tf.distribute.Strategy` object."""
+ return self._distribution_strategy
+
+ @property
+ def devices(self):
+ """The devices this replica is to be executed on, as a list of strings."""
+ require_replica_context(self)
+ return [device_util.current()]
+
+ # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
+ # all-reduce. It would return a function returning the result of reducing `t`
+ # across all replicas. The caller would wait to call this function until they
+ # needed the reduce result, allowing an efficient implementation:
+ # * With eager execution, the reduction could be performed asynchronously
+ # in the background, not blocking until the result was needed.
+ # * When constructing a graph, it could batch up all reduction requests up
+ # to that point that the first result is needed. Most likely this can be
+ # implemented in terms of `merge_call()` and `batch_reduce_to()`.
+
+# ------------------------------------------------------------------------------
+
+
+class _DefaultDistributionStrategy(DistributionStrategy):
+ """Default `tf.distribute.Strategy` if none is explicitly selected."""
+
+ def __init__(self):
+ super(_DefaultDistributionStrategy, self).__init__(
+ _DefaultDistributionExtended(self))
+
+
+class _DefaultDistributionExtended(DistributionStrategyExtended):
+ """Implementation of _DefaultDistributionStrategy."""
+
+ def _scope(self, strategy):
+ """Context manager setting a variable creator and `self` as current."""
+ if distribution_strategy_context.has_distribution_strategy():
+ raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")
+
+ def creator(next_creator, *args, **kwargs):
+ _require_distribution_strategy_scope_strategy(strategy)
+ return next_creator(*args, **kwargs)
+
+ return _CurrentDistributionContext(
+ strategy, variable_scope.variable_creator_scope(creator))
+
+ def colocate_vars_with(self, colocate_with_variable):
+ """Does not require `self.scope`."""
+ _require_distribution_strategy_scope_extended(self)
+ return ops.colocate_with(colocate_with_variable)
+
+ def _distribute_dataset(self, dataset_fn):
+ return self._call_dataset_fn(dataset_fn)
+
+ def _make_dataset_iterator(self, dataset):
+ return _DefaultDistributionExtended.DefaultInputIterator(dataset)
+
+ def _make_input_fn_iterator(self,
+ input_fn,
+ replication_mode=InputReplicationMode.PER_WORKER):
+ return input_fn(InputContext()).make_initializable_iterator()
+
+ def _broadcast_to(self, tensor, destinations):
+ if destinations is None:
+ return tensor
+ else:
+ raise NotImplementedError("TODO")
+
+ def _call_for_each_replica(self, fn, args, kwargs):
+ with ReplicaContext(
+ self._container_strategy(),
+ replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
+ return fn(*args, **kwargs)
+
+ def _reduce_to(self, reduce_op, value, destinations):
+ # TODO(josh11b): Use destinations?
+ del reduce_op, destinations
+ return value
+
+ def _update(self, var, fn, args, kwargs, group):
+ # The implementations of _update() and _update_non_slot() are identical
+ # except _update() passes `var` as the first argument to `fn()`.
+ return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)
+
+ def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group):
+ # TODO(josh11b): Figure out what we should be passing to UpdateContext()
+ # once that value is used for something.
+ with ops.colocate_with(colocate_with), UpdateContext(colocate_with):
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
+
+ def read_var(self, replica_local_var):
+ return array_ops.identity(replica_local_var)
+
+ def _unwrap(self, distributed_value):
+ return [distributed_value]
+
+ def value_container(self, value):
+ return value
+
+ @property
+ def _num_replicas_in_sync(self):
+ return 1
+
+ @property
+ def worker_devices(self):
+ raise RuntimeError("worker_devices() method unsupported by default "
+ "tf.distribute.Strategy.")
+
+ @property
+ def parameter_devices(self):
+ raise RuntimeError("parameter_devices() method unsupported by default "
+ "tf.distribute.Strategy.")
+
+ def non_slot_devices(self, var_list):
+ return min(var_list, key=lambda x: x.name)
+
+ # TODO(priyag): This should inherit from `InputIterator`, once dependency
+ # issues have been resolved.
+ class DefaultInputIterator(object):
+ """Default implementation of `InputIterator` for default strategy."""
+
+ def __init__(self, dataset):
+ self._dataset = dataset
+ if eager_context.executing_eagerly():
+ self._iterator = dataset.make_one_shot_iterator()
+ else:
+ self._iterator = dataset.make_initializable_iterator()
+
+ def get_next(self):
+ return self._iterator.get_next()
+
+ def initialize(self):
+ if eager_context.executing_eagerly():
+ self._iterator = self._dataset.make_one_shot_iterator()
+ return []
+ else:
+ return [self._iterator.initializer]
+
+ # TODO(priyag): Delete this once all strategies use global batch size.
+ @property
+ def _global_batch_size(self):
+ return True
+
+
+# ------------------------------------------------------------------------------
+# We haven't yet implemented deserialization for DistributedVariables.
+# So here we catch any attempts to deserialize variables
+# when using distribution strategies.
+# pylint: disable=protected-access
+_original_from_proto = resource_variable_ops._from_proto_fn
+
+
+def _from_proto_fn(v, import_scope=None):
+ if distribution_strategy_context.has_distribution_strategy():
+ raise NotImplementedError(
+ "Deserialization of variables is not yet supported when using a "
+ "tf.distribute.Strategy.")
+ else:
+ return _original_from_proto(v, import_scope=import_scope)
+
+resource_variable_ops._from_proto_fn = _from_proto_fn
+# pylint: enable=protected-access
+
+
+#-------------------------------------------------------------------------------
+# Shorthand for some methods from distribution_strategy_context.
+_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access
+_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access
+_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access
diff --git a/tensorflow/python/training/distribute_test.py b/tensorflow/python/distribute/distribute_lib_test.py
similarity index 96%
rename from tensorflow/python/training/distribute_test.py
rename to tensorflow/python/distribute/distribute_lib_test.py
index ad4d50c..d63d1fe 100644
--- a/tensorflow/python/training/distribute_test.py
+++ b/tensorflow/python/distribute/distribute_lib_test.py
@@ -18,12 +18,12 @@
from __future__ import division
from __future__ import print_function
+from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
-from tensorflow.python.training import distribute as distribute_lib
-from tensorflow.python.training import distribution_strategy_context
class _TestReplicaContext(distribute_lib.ReplicaContext):
@@ -92,9 +92,9 @@
variable_scope.variable(1.0, name="bar"))
with self.assertRaises(RuntimeError):
- dist.call_for_each_replica(run_fn)
+ dist.extended.call_for_each_replica(run_fn)
with dist.scope():
- dist.call_for_each_replica(run_fn)
+ dist.extended.call_for_each_replica(run_fn)
_assert_in_default_state(self)
def testScope(self):
diff --git a/tensorflow/python/distribute/distribution_strategy_context.py b/tensorflow/python/distribute/distribution_strategy_context.py
new file mode 100644
index 0000000..78e096e
--- /dev/null
+++ b/tensorflow/python/distribute/distribution_strategy_context.py
@@ -0,0 +1,236 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility to get distribution strategy related contexts."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.util.lazy_loader import LazyLoader
+from tensorflow.python.util.tf_export import tf_export
+
+
+# There is a circular dependency between this and `distribute` module. So we
+# load it lazily to workaround this.
+distribute_lib = LazyLoader(
+ "distribute_lib", globals(),
+ "tensorflow.python.distribute.distribute_lib")
+
+# ------------------------------------------------------------------------------
+# Internal API for setting the current thread mode as being either in a
+# replica or cross-replica context for a particular distribution strategy.
+
+
+class _ThreadMode(object):
+
+ def __init__(self, dist, cross, replica):
+ self.distribution_strategy = dist
+ self.cross_replica_context = cross
+ self.replica_context = replica
+
+
+class _CrossReplicaThreadMode(_ThreadMode):
+
+ def __init__(self, distribution_strategy):
+ _ThreadMode.__init__(
+ self, distribution_strategy, distribution_strategy, None)
+
+
+class _InReplicaThreadMode(_ThreadMode):
+
+ def __init__(self, replica_ctx):
+ _ThreadMode.__init__(
+ self, replica_ctx.distribution_strategy, None, replica_ctx)
+
+
+def _push_per_thread_mode(context):
+ ops.get_default_graph()._distribution_strategy_stack.append(context) # pylint: disable=protected-access
+
+
+def _pop_per_thread_mode():
+ ops.get_default_graph()._distribution_strategy_stack.pop(-1) # pylint: disable=protected-access
+
+
+class _DefaultReplicaThreadMode(_ThreadMode):
+ """Type of default value returned by `_get_per_thread_mode()`.
+
+ Used when the thread-local stack is empty.
+ """
+
+ def __init__(self):
+ _ThreadMode.__init__(self, _get_default_distribution_strategy(), None,
+ _get_default_replica_context())
+
+
+def _get_per_thread_mode():
+ try:
+ return ops.get_default_graph()._distribution_strategy_stack[-1] # pylint: disable=protected-access
+ except (AttributeError, IndexError):
+ return _get_default_replica_mode()
+
+
+# ------------------------------------------------------------------------------
+# Public API for accessing the current thread mode
+
+
+@tf_export("distribute.get_replica_context")
+def get_replica_context():
+ """Returns the current `tf.distribute.ReplicaContext` or `None`.
+
+ Returns `None` if in a cross-replica context.
+
+ Note that execution:
+
+ 1. starts in the default (single-replica) replica context (this function
+ will return the default `ReplicaContext` object);
+ 2. switches to cross-replica context (in which case this will return
+ `None`) when entering a `with tf.distribute.Strategy.scope():` block;
+ 3. switches to a (non-default) replica context inside
+ `extended.call_for_each_replica(fn, ...)`;
+ 4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
+ inside `merge_fn` you are back in the cross-replica context (and again
+ this function will return `None`).
+
+ Note that you can also go directly from step 1 to 4 to switch to a
+ cross-replica context for the default `tf.distribute.Strategy`. You may
+ also switch from the cross-replica context of 4 to a replica context by
+ calling `extended.call_for_each_replica()`, jumping back to step 3.
+
+ Most `tf.distribute.Strategy` methods may only be executed in
+ a cross-replica context, in a replica context you should use the
+ `ReplicaContext` API instead.
+
+ Returns:
+ The current `ReplicaContext` object when in a replica context scope,
+ else `None`.
+
+ Within a particular block, exactly one of these two things will be true:
+
+ * `get_replica_context()` returns non-`None`, or
+ * `tf.distribute.is_cross_replica_context()` returns True.
+ """
+ return _get_per_thread_mode().replica_context
+
+
+def get_cross_replica_context():
+ """Returns the current tf.distribute.Strategy if in a cross-replica context.
+
+ DEPRECATED: Please use `in_cross_replica_context()` and
+ `get_distribution_strategy()` instead.
+
+ Note that execution:
+
+ 1. starts in the default (single-replica) replica context;
+ 2. switches to cross-replica context when entering a
+ `with tf.distribute.Strategy.scope():` block;
+ 3. switches to a (non-default) replica context inside
+ `call_for_each_replica(fn, ...)`;
+ 4. if `fn` calls `get_replica_context()->merge_call(merge_fn, ...)`, then
+ inside `merge_fn` you are back in the cross-replica context.
+
+ Note that you can also go directly from step 1 to 4 to switch to a
+ cross-replica context for the default `tf.distribute.Strategy`. You may
+ also switch from the cross-replica context of 4 to a replica context by
+ calling `call_for_each_replica()`, jumping back to step 3.
+
+ Most `tf.distribute.Strategy` methods may only be executed in
+ a cross-replica context.
+
+ Returns:
+ Returns the current `tf.distribute.Strategy` object in a cross-replica
+ context, or `None`.
+
+ Exactly one of `get_replica_context()` and `get_cross_replica_context()`
+ will return `None` in a particular block.
+ """
+ return _get_per_thread_mode().cross_replica_context
+
+
+@tf_export("distribute.in_cross_replica_context")
+def in_cross_replica_context():
+ """Returns True if in a cross-replica context.
+
+ See `tf.distribute.get_replica_context` for details.
+
+ Returns:
+ True if in a cross-replica context (`get_replica_context()` returns
+ `None`), or False if in a replica context (`get_replica_context()` returns
+ non-`None`).
+ """
+ return _get_per_thread_mode().cross_replica_context is not None
+
+
+@tf_export("distribute.get_strategy")
+def get_distribution_strategy():
+ """Returns the current `tf.distribute.Strategy` object.
+
+ Typically only used in a cross-replica context:
+
+ ```
+ if tf.distribute.in_cross_replica_context():
+ strategy = tf.distribute.get_strategy()
+ ...
+ ```
+
+ Returns:
+ A `tf.distribute.Strategy` object. Inside a
+ `with distribution_strategy.scope()` block, it returns
+ `distribution_strategy`, otherwise it returns the default
+ (single-replica) `tf.distribute.Strategy` object.
+ """
+ return _get_per_thread_mode().distribution_strategy
+
+
+@tf_export("distribute.has_strategy")
+def has_distribution_strategy():
+ """Return if there is a current non-default `tf.distribute.Strategy`.
+
+ Returns:
+ True if inside a `with strategy.scope():`.
+ """
+ return get_distribution_strategy() is not _get_default_distribution_strategy()
+
+
+# ------------------------------------------------------------------------------
+# Defaults that are used when no distribution strategy is explicitly created.
+# We create them lazily in a function so that we can workaround the circular
+# dependency on distribute_lib. See lazy loader at the top of this file.
+
+_defaults = {
+ "distribution_strategy": None,
+ "replica_context": None,
+ "replica_mode": None
+}
+
+
+def _get_default_distribution_strategy():
+ if _defaults["distribution_strategy"] is None:
+ _defaults["distribution_strategy"] = (
+ distribute_lib._DefaultDistributionStrategy()) # pylint: disable=protected-access
+ return _defaults["distribution_strategy"]
+
+
+def _get_default_replica_context():
+ if _defaults["replica_context"] is None:
+ _defaults["replica_context"] = distribute_lib.ReplicaContext(
+ _get_default_distribution_strategy(), replica_id_in_sync_group=0)
+ return _defaults["replica_context"]
+
+
+def _get_default_replica_mode():
+ if _defaults["replica_mode"] is None:
+ _defaults["replica_mode"] = _DefaultReplicaThreadMode()
+ return _defaults["replica_mode"]
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py
index 227b00f..549fa8f 100644
--- a/tensorflow/python/distribute/estimator_training.py
+++ b/tensorflow/python/distribute/estimator_training.py
@@ -308,7 +308,7 @@
raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call '
'`estimator.train`')
- if estimator._config._train_distribute.between_graph:
+ if estimator._config._train_distribute.extended.experimental_between_graph:
# TODO(yuefengz): remove this limitation once we figure out how to merge
# return values from `_worker_fn`s.
raise ValueError('`Estimator.train` API is not supported for %s with '
@@ -356,7 +356,7 @@
raise ValueError('Only `STANDALONE_CLIENT` mode is supported when you call '
'`Estimator.train`')
- if estimator._config._eval_distribute.between_graph:
+ if estimator._config._eval_distribute.extended.experimental_between_graph:
# TODO(yuefengz): remove this limitation once we figure out how to merge
# return values from `_worker_fn`s.
raise ValueError('`Estimator.evaluate` API is not supported for %s with '
diff --git a/tensorflow/python/distribute/input_ops_test.py b/tensorflow/python/distribute/input_ops_test.py
index 54f7c5d..2689dbb 100644
--- a/tensorflow/python/distribute/input_ops_test.py
+++ b/tensorflow/python/distribute/input_ops_test.py
@@ -94,7 +94,7 @@
for r in range(self._num_records):
self.assertAllEqual(record_fn(r, f), self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testTFRecordDataset(self):
dataset = readers.TFRecordDataset(self._createTFRecordFiles())
@@ -138,10 +138,10 @@
actual, expected = [], []
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
- actual.append(sess.run(next_element))
+ actual.append(self.evaluate(next_element))
expected.append(self._record(r, f))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
self.assertAllEqual(expected, actual)
def testComplexPipeline(self):
@@ -171,9 +171,9 @@
num_iterations = (self._num_files * self._num_records * num_epochs) // (
self._num_shards * batch_size)
for _ in range(num_iterations):
- actual.extend(sess.run(next_element))
+ actual.extend(self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
expected = []
for f in range(0, self._num_files, self._num_shards):
@@ -211,7 +211,7 @@
self.assertAllEqual(
self._text_line(r, f), self.evaluate(next_element))
with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
+ self.evaluate(next_element)
def testTextLineReader(self):
dataset = readers.TextLineDataset(self._createTextFiles())
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
new file mode 100644
index 0000000..d6d40df
--- /dev/null
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -0,0 +1,813 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class MirroredStrategy implementing DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import copy
+import functools
+import threading
+
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
+from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import multi_worker_util
+from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import shared_variable_creator
+from tensorflow.python.distribute import values
+from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import device as tf_device
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import coordinator
+from tensorflow.python.util import nest
+
+
+# TODO(josh11b): Replace asserts in this file with if ...: raise ...
+
+
+@contextlib.contextmanager
+def _enter_graph(g):
+ if context.executing_eagerly():
+ with g.as_default(), context.eager_mode():
+ yield
+ else:
+ with g.as_default():
+ yield
+
+
+def _cpu_device(device):
+ cpu_device = tf_device.DeviceSpec.from_string(device)
+ cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0))
+ return cpu_device.to_string()
+
+
+class _RequestedStop(Exception): # pylint: disable=g-bad-exception-name
+ pass
+
+
+# _call_for_each_replica and _reduce_non_distributed_value are not members of
+# MirroredStrategy so that they are generally not allowed to use anything
+# specific to MirroredStrategy and thus can be shared with other distribution
+# strategies.
+
+
+# TODO(yuefengz): maybe create a common class for those who need to call this
+# _call_for_each_replica.
+def _call_for_each_replica(distribution, fn, args, kwargs):
+ """Run `fn` in separate threads, once per replica/worker device.
+
+ Args:
+ distribution: the DistributionStrategy object.
+ fn: function to run (will be run once per device, each in its own thread).
+ args: positional arguments for `fn`
+ kwargs: keyword arguments for `fn`.
+
+ Returns:
+ Merged return value of `fn` across all replicas.
+
+ Raises:
+ RuntimeError: If fn() calls get_replica_context().merge_call() a different
+ number of times from the available devices.
+ """
+ # TODO(josh11b): Add this option once we add synchronization to variable
+ # creation. Until then, this is pretty unsafe to use.
+ run_concurrently = False
+ if not context.executing_eagerly():
+ # Needed for per-thread device, etc. contexts in graph mode.
+ ops.get_default_graph().switch_to_thread_local()
+
+ coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))
+
+ shared_variable_store = {}
+
+ # TODO(isaprykin): Create these threads once instead of during every run()
+ # call.
+ threads = []
+ for index, d in enumerate(distribution.extended.worker_devices):
+ variable_creator_fn = shared_variable_creator.make_fn(
+ shared_variable_store, index)
+ t = MirroredExtended._MirroredReplicaThread( # pylint: disable=protected-access
+ distribution, coord, d, variable_creator_fn, fn,
+ *values.select_device(d, args), **values.select_device(d, kwargs))
+ threads.append(t)
+
+ for t in threads:
+ t.start()
+
+ # When `fn` starts `should_run` event is set on _MirroredReplicaThread
+ # (`MRT`) threads. The execution waits until
+ # `MRT.has_paused` is set, which indicates that either `fn` is
+ # complete or a `get_replica_context().merge_call()` is called. If `fn` is
+ # complete, then `MRT.done` is set to True. Otherwise, arguments
+ # of `get_replica_context().merge_call` from all paused threads are grouped
+ # and the `merge_fn` is performed. Results of the
+ # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
+ # Each such `get_replica_context().merge_call` call returns the
+ # `MRT.merge_result` for that thread when `MRT.should_run` event
+ # is reset again. Execution of `fn` resumes.
+
+ try:
+ with coord.stop_on_exception():
+ all_done = False
+ while not all_done and not coord.should_stop():
+ done = []
+ if run_concurrently:
+ for t in threads:
+ t.should_run.set()
+ for t in threads:
+ t.has_paused.wait()
+ t.has_paused.clear()
+ if coord.should_stop():
+ return None
+ done.append(t.done)
+ else:
+ for t in threads:
+ t.should_run.set()
+ t.has_paused.wait()
+ t.has_paused.clear()
+ if coord.should_stop():
+ return None
+ done.append(t.done)
+ if coord.should_stop():
+ return None
+ all_done = all(done)
+ if not all_done:
+ if any(done):
+ raise RuntimeError("Some replicas made a different number of "
+ "replica_context().merge_call() calls.")
+ # get_replica_context().merge_call() case
+ merge_args = values.regroup({t.device: t.merge_args for t in threads})
+ merge_kwargs = values.regroup(
+ {t.device: t.merge_kwargs for t in threads})
+ # We capture the name_scope of the MRT when we call merge_fn
+ # to ensure that if we have opened a name scope in the MRT,
+ # it will be respected when executing the merge function. We only
+ # capture the name_scope from the first MRT and assume it is
+ # the same for all other MRTs.
+ mtt_captured_name_scope = threads[0].captured_name_scope
+ with ops.name_scope(mtt_captured_name_scope):
+ merge_result = threads[0].merge_fn(distribution, *merge_args,
+ **merge_kwargs)
+ for t in threads:
+ t.merge_result = values.select_device(t.device, merge_result)
+ finally:
+ for t in threads:
+ t.should_run.set()
+ coord.join(threads)
+
+ return values.regroup({t.device: t.main_result for t in threads})
+
+
+def _reduce_non_distributed_value(extended, reduce_op, value, destinations):
+ """Reduce a non-DistributedValue `value` to `destinations`."""
+ if isinstance(value, values.DistributedValues):
+ raise ValueError("You are passing a `DistributedValue` to "
+ "`_reduce_non_distributed_value`, which is not allowed.")
+
+ # If the same value is present on all replicas then the PerReplica value will
+ # be a single value. We also handle the case when `value` is a single value
+ # and equal to 0.
+ if value == 0:
+ return 0
+ # If there is only a single value and the reduce op is MEAN,
+ # that value should be on all destinations.
+ if reduce_op == reduce_util.ReduceOp.MEAN:
+ return value
+
+ cross_device_ops_lib.validate_destinations(destinations)
+ # We do not support a reduce op of SUM if the value is the same across
+ # all replicas. We call this as part of assign functions for MirroredVariables
+ # and summing up identical values across replicas is not clearly defined.
+ if (len(extended.worker_devices) != 1 or
+ not cross_device_ops_lib.check_destinations(destinations)):
+ raise ValueError("A non-DistributedValues value %s cannot be reduced with "
+ "the given reduce op %s." % (value, reduce_op))
+ # TODO(anjalisridhar): Moves these methods to a device utility file?
+ devices = cross_device_ops_lib.get_devices_from(destinations)
+ if len(devices) == 1:
+ with ops.device(devices[0]):
+ return array_ops.identity(value)
+ else:
+ value_updates = {}
+ for d in devices:
+ with ops.device(d):
+ value_updates[d] = array_ops.identity(value)
+ return values.Mirrored(value_updates)
+
+
+def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # pylint: disable=g-missing-docstring
+ # Figure out what collections this variable should be added to.
+ # We'll add the MirroredVariable to those collections instead.
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ # Get synchronization value
+ synchronization = kwargs.get("synchronization",
+ variable_scope.VariableSynchronization.ON_WRITE)
+ if synchronization == variable_scope.VariableSynchronization.NONE:
+ raise ValueError("`NONE` variable synchronization mode is not "
+ "supported with `Mirrored` distribution strategy. Please"
+ " change the `synchronization` for variable: " +
+ kwargs["name"])
+ elif synchronization == variable_scope.VariableSynchronization.ON_READ:
+ # Variables that are to be synced on read are replica local.
+ is_replica_local = True
+ kwargs["trainable"] = False
+ elif (synchronization == variable_scope.VariableSynchronization.ON_WRITE or
+ synchronization == variable_scope.VariableSynchronization.AUTO):
+ # `AUTO` synchronization for `MirroredStrategy` is `ON_WRITE`.
+ is_replica_local = False
+ else:
+ raise ValueError("Invalid variable synchronization mode: " +
+ synchronization + " for variable: " + kwargs["name"])
+
+ # Get aggregation value
+ aggregation = kwargs.pop("aggregation",
+ variable_scope.VariableAggregation.NONE)
+ if aggregation not in (
+ variable_scope.VariableAggregation.NONE,
+ variable_scope.VariableAggregation.SUM,
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_REPLICA
+ ):
+ raise ValueError("Invalid variable aggregation mode: " + aggregation +
+ " for variable: " + kwargs["name"])
+
+ # Ignore user-specified caching device, not needed for mirrored variables.
+ kwargs.pop("caching_device", None)
+
+ # TODO(josh11b,apassos): It would be better if variable initialization
+ # was never recorded on the tape instead of having to do this manually
+ # here.
+ with tape.stop_recording():
+ index = real_mirrored_creator(devices, *args, **kwargs)
+
+ if is_replica_local:
+ result = values.ReplicaLocalVariable(
+ index, index[devices[0]], aggregation)
+ else:
+ result = values.MirroredVariable(index, index[devices[0]], aggregation)
+
+ # Add the wrapped variable to the requested collections.
+ # The handling of eager mode and the global step matches
+ # ResourceVariable._init_from_args().
+ if not context.executing_eagerly():
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the member variables
+ # to the TRAINABLE_VARIABLES collection, so we manually remove
+ # them and replace with the MirroredVariable. We can't set
+ # "trainable" to False for next_creator() since that causes functions
+ # like implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ for v in index.values():
+ if v in l:
+ l.remove(v)
+ g.add_to_collections(collections, result)
+ elif ops.GraphKeys.GLOBAL_STEP in collections:
+ ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
+
+ return result
+
+
+class MirroredStrategy(distribute_lib.DistributionStrategy):
+ """Mirrors vars to distribute across multiple devices and machines.
+
+ This strategy uses one replica per device and sync replication for its
+ multi-GPU version.
+
+ The multi-worker version will be added in the fture.
+
+ Args:
+ devices: a list of device strings.
+ num_gpus_per_worker: number of GPUs per worker.
+ cross_device_ops: optional, a descedant of `CrossDeviceOps`. If this is not
+ set, nccl will be use by default.
+ """
+
+ def __init__(self,
+ devices=None,
+ num_gpus_per_worker=None,
+ cross_device_ops=None):
+ extended = MirroredExtended(self, devices, num_gpus_per_worker,
+ cross_device_ops)
+ super(MirroredStrategy, self).__init__(extended)
+
+
+class MirroredExtended(distribute_lib.DistributionStrategyExtended):
+ """Implementation of MirroredStrategy."""
+
+ def __init__(self,
+ container_strategy,
+ devices=None,
+ num_gpus_per_worker=None,
+ cross_device_ops=None):
+ super(MirroredExtended, self).__init__(container_strategy)
+ self._cross_device_ops = cross_device_ops
+ # Remember num GPUs which might be needed by `configure` method.
+ self._num_gpus = num_gpus_per_worker
+
+ self._initialize_local(self._num_gpus, devices)
+
+ def _initialize_local(self, num_gpus, devices):
+ """Initializes the object for local training."""
+ self._cluster_spec = None
+ # Convert `num_gpus` into `devices`, shouldn't specify both.
+ if devices is None:
+ if num_gpus is None:
+ num_gpus = context.num_gpus()
+ if num_gpus == 0:
+ devices = ["/device:CPU:0"]
+ else:
+ devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
+ elif num_gpus is not None:
+ raise ValueError("Must only specify one of `devices` and `num_gpus`.")
+ self._num_gpus = num_gpus
+ # TODO(yuefengz): consider setting the default device.
+
+ assert devices, "Must specify at least one device."
+ assert len(set(devices)) == len(devices), (
+ "No duplicates allowed in `devices` argument.")
+ # TODO(josh11b): Require at least 2 devices?
+ self._devices = [device_util.resolve(d) for d in devices]
+ self._canonical_device_set = set(self._devices)
+ self._device_index = values.PerReplica(
+ {d: i for i, d in enumerate(devices)})
+
+ def _initialize_multi_worker(self, num_gpus, cluster_spec):
+ """Initializes the object for multi-worker training."""
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
+ self._cluster_spec = cluster_spec
+
+ self._workers = []
+ for job in ["chief", "worker"]:
+ for task in range(len(cluster_spec.as_dict().get(job, []))):
+ self._workers.append("/job:%s/task:%d" % (job, task))
+
+ if num_gpus is None:
+ raise ValueError("`num_gpus` is required if `cluster_spec` is given.")
+ if num_gpus > 0:
+ self._worker_devices = [
+ (worker, [
+ device_util.canonicalize(worker + "/device:GPU:%d" % gpu)
+ for gpu in range(num_gpus)
+ ]) for worker in self._workers
+ ]
+ else:
+ self._worker_devices = [
+ (worker, [device_util.canonicalize(worker, "/device:CPU:0")])
+ for worker in self._workers
+ ]
+
+ devices = nest.flatten([l for _, l in self._worker_devices])
+
+ # Setting `_default_device` will add a device scope in the
+ # distribution.scope. We set the default device to the first worker. When
+ # users specify device under distribution.scope by
+ # with tf.device("/cpu:0"):
+ # ...
+ # their ops will end up on the cpu device of its first worker, e.g.
+ # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
+ self._default_device = self._workers[0]
+
+ assert devices, "Must specify at least one device."
+ assert len(set(devices)) == len(devices), (
+ "No duplicates allowed in `devices` argument.")
+ # TODO(josh11b): Require at least 2 devices?
+ self._devices = [device_util.resolve(d) for d in devices]
+ self._canonical_device_set = set(self._devices)
+ self._device_index = values.PerReplica(
+ {d: i for i, d in enumerate(devices)})
+
+ def _create_variable(self, next_creator, *args, **kwargs):
+ """Create a mirrored variable. See `DistributionStrategy.scope`."""
+ colocate_with = kwargs.pop("colocate_with", None)
+ devices = self._get_devices_from(colocate_with)
+
+ def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring
+ index = {}
+ for i, d in enumerate(devices):
+ with ops.device(d):
+ if i > 0:
+ # Give replicas meaningful distinct names:
+ var0name = index[devices[0]].name.split(":")[0]
+ # We append a / to variable names created on replicas with id > 0 to
+ # ensure that we ignore the name scope and instead use the given
+ # name as the absolute name of the variable.
+ kwargs["name"] = "%s/replica_%d/" % (var0name, i)
+ # Initialize replicas with the same value:
+ def initial_value_fn(device=d):
+ if context.executing_eagerly():
+ init_value = index[devices[0]].value()
+ return array_ops.identity(init_value)
+ else:
+ with ops.device(device):
+ init_value = index[devices[0]].initial_value
+ return array_ops.identity(init_value)
+ kwargs["initial_value"] = initial_value_fn
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ # Don't record operations (e.g. other variable reads) during
+ # variable creation.
+ with tape.stop_recording():
+ v = next_creator(*args, **kwargs)
+ assert not isinstance(v, values.DistributedVariable)
+ index[d] = v
+ return index
+
+ return _create_mirrored_variable(devices, _real_mirrored_creator, *args,
+ **kwargs)
+
+ def _distribute_dataset(self, dataset_fn):
+ if self._cluster_spec:
+ return values.MultiWorkerDataset(
+ functools.partial(self._call_dataset_fn, dataset_fn),
+ self._worker_devices,
+ auto_shard=False)
+ else:
+ return values.PerReplicaDataset(
+ self._call_dataset_fn(dataset_fn), self._devices)
+
+ def _make_dataset_iterator(self, dataset):
+ if self._cluster_spec:
+ worker_device_pairs = self._worker_devices
+ else:
+ worker = device_util.canonicalize("/device:CPU:0")
+ worker_device_pairs = [(worker, self._devices)]
+ return values.DatasetIterator(dataset, worker_device_pairs,
+ self._num_replicas_in_sync)
+
+ def _make_input_fn_iterator(
+ self,
+ input_fn,
+ replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
+ input_contexts = []
+ if self._cluster_spec:
+ num_workers = len(self._worker_devices)
+ worker_device_pairs = self._worker_devices
+ else:
+ num_workers = 1
+ worker = device_util.canonicalize("/device:CPU:0")
+ worker_device_pairs = [(worker, self._devices)]
+ for i in range(num_workers):
+ input_contexts.append(distribute_lib.InputContext(
+ num_input_pipelines=num_workers,
+ input_pipeline_id=i,
+ num_replicas_in_sync=self._num_replicas_in_sync))
+ return values.InputFunctionIterator(
+ input_fn, worker_device_pairs, input_contexts)
+
+ # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
+ def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
+ initial_loop_values=None):
+ if initial_loop_values is None:
+ initial_loop_values = {}
+ initial_loop_values = nest.flatten(initial_loop_values)
+
+ ctx = values.MultiStepContext()
+ def body(i, *args):
+ """A wrapper around `fn` to create the while loop body."""
+ del args
+ fn_inputs = iterator.get_next()
+ if not isinstance(fn_inputs, tuple):
+ fn_inputs = (fn_inputs,)
+ fn_result = fn(ctx, fn_inputs)
+ for (name, output) in ctx.last_step_outputs.items():
+ # Convert all outputs to tensors, potentially from `DistributedValues`.
+ ctx.last_step_outputs[name] = self._unwrap(output)
+ flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
+ with ops.control_dependencies([fn_result]):
+ return [i + 1] + flat_last_step_outputs
+
+ # We capture the control_flow_context at this point, before we run `fn`
+ # inside a while_loop. This is useful in cases where we might need to exit
+ # these contexts and get back to the outer context to do some things, for
+ # e.g. create an op which should be evaluated only once at the end of the
+ # loop on the host. One such usage is in creating metrics' value op.
+ self._outer_control_flow_context = (
+ ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
+
+ cond = lambda i, *args: i < iterations
+ i = constant_op.constant(0)
+ loop_result = control_flow_ops.while_loop(
+ cond, body, [i] + initial_loop_values, name="",
+ parallel_iterations=1, back_prop=False, swap_memory=False,
+ return_same_structure=True)
+ del self._outer_control_flow_context
+
+ ctx.run_op = control_flow_ops.group(loop_result)
+
+ # Convert the last_step_outputs from a list to the original dict structure
+ # of last_step_outputs.
+ last_step_tensor_outputs = loop_result[1:]
+ last_step_tensor_outputs_dict = nest.pack_sequence_as(
+ ctx.last_step_outputs, last_step_tensor_outputs)
+
+ for name, reduce_op in ctx._last_step_outputs_reduce_ops.items(): # pylint: disable=protected-access
+ output = last_step_tensor_outputs_dict[name]
+ # For outputs that have already been reduced, wrap them in a Mirrored
+ # container, else in a PerReplica container.
+ if reduce_op is None:
+ last_step_tensor_outputs_dict[name] = values.regroup(
+ {d: t for d, t in zip(self._devices, output)}, values.PerReplica)
+ else:
+ assert len(output) == 1
+ last_step_tensor_outputs_dict[name] = output[0]
+
+ ctx._set_last_step_outputs(last_step_tensor_outputs_dict) # pylint: disable=protected-access
+ return ctx
+
+ def _broadcast_to(self, tensor, destinations):
+ # This is both a fast path for Python constants, and a way to delay
+ # converting Python values to a tensor until we know what type it
+ # should be converted to. Otherwise we have trouble with:
+ # global_step.assign_add(1)
+ # since the `1` gets broadcast as an int32 but global_step is int64.
+ if isinstance(tensor, (float, int)):
+ return tensor
+ # TODO(josh11b): In eager mode, use one thread per device, or async mode.
+ return self._get_cross_device_ops().broadcast(
+ tensor, destinations or self._devices)
+
+ def _call_for_each_replica(self, fn, args, kwargs):
+ return _call_for_each_replica(self._container_strategy(), fn, args, kwargs)
+
+ def _configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ del task_type, task_id
+
+ if session_config:
+ session_config.CopyFrom(self._update_config_proto(session_config))
+
+ if cluster_spec:
+ self._initialize_multi_worker(self._num_gpus, cluster_spec)
+
+ if self._cross_device_ops is None:
+ if self._cluster_spec:
+ # It currently cannot detect the toplogy of remote workers. So we
+ # hard-code the multi-worker all-reduce algorithm for now.
+ if len(self._workers) == 1:
+ # The default is "nccl".
+ self._cross_device_ops = (
+ cross_device_ops_lib.AllReduceCrossDeviceOps())
+ else:
+ # The default is hierarchical reduce and broadcast.
+ self._cross_device_ops = cross_device_ops_lib.MultiWorkerAllReduce(
+ self._workers, self._num_gpus)
+ else:
+ self._cross_device_ops = cross_device_ops_lib.choose_the_best(
+ self._devices, session_config=session_config)
+
+ def _update_config_proto(self, config_proto):
+ updated_config = copy.deepcopy(config_proto)
+ updated_config.isolate_session_state = True
+ return updated_config
+
+ def _get_cross_device_ops(self):
+ if self._cross_device_ops is None:
+ self._cross_device_ops = (
+ cross_device_ops_lib.ReductionToOneDeviceCrossDeviceOps())
+ return self._cross_device_ops
+
+ def _reduce_to(self, reduce_op, value, destinations):
+ assert not isinstance(value, values.Mirrored)
+ if not isinstance(value, values.DistributedValues):
+ # This function handles reducing values that are not PerReplica or
+ # Mirrored values. For example, the same value could be present on all
+ # replicas in which case `value` would be a single value or value could
+ # be 0.
+ return _reduce_non_distributed_value(self, reduce_op, value,
+ destinations)
+ return self._get_cross_device_ops().reduce(
+ reduce_op, value, destinations=destinations)
+
+ def _batch_reduce_to(self, reduce_op, value_destination_pairs):
+ return self._get_cross_device_ops().batch_reduce(reduce_op,
+ value_destination_pairs)
+
+ def _update(self, var, fn, args, kwargs, group):
+ # TODO(josh11b): In eager mode, use one thread per device.
+ assert isinstance(var, values.DistributedVariable)
+ updates = {}
+ for d, v in var._index.items(): # pylint: disable=protected-access
+ name = "update_%d" % self._device_index.get(d)
+ with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
+ # If args and kwargs are not mirrored, the value is returned as is.
+ updates[d] = fn(v,
+ *values.select_device_mirrored(d, args),
+ **values.select_device_mirrored(d, kwargs))
+ return values.update_regroup(self, updates, group)
+
+ def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
+ assert isinstance(colocate_with, list)
+ # TODO(josh11b): In eager mode, use one thread per device.
+ updates = {}
+ for d in colocate_with:
+ name = "update_%d" % self._device_index.get(d)
+ with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
+ updates[d] = fn(*values.select_device_mirrored(d, args),
+ **values.select_device_mirrored(d, kwargs))
+ return values.update_regroup(self, updates, group)
+
+ def read_var(self, replica_local_var):
+ """Read the aggregate value of a replica-local variable."""
+ if isinstance(replica_local_var, values.ReplicaLocalVariable):
+ return replica_local_var._get_cross_replica() # pylint: disable=protected-access
+ assert isinstance(replica_local_var, values.Mirrored)
+ return array_ops.identity(replica_local_var.get())
+
+ def _unwrap(self, val):
+ if isinstance(val, values.DistributedValues):
+ # Return in a deterministic order.
+ if set(val.devices) == self._canonical_device_set:
+ return [val.get(device=d) for d in self._devices]
+ return [val.get(device=d) for d in sorted(val.devices)]
+ return [val]
+
+ def value_container(self, val):
+ return values.value_container(val)
+
+ @property
+ def _num_replicas_in_sync(self):
+ return len(self._devices)
+
+ @property
+ def worker_devices(self):
+ # Make a copy to prevent users from accidentally mutating our copy.
+ return list(self._devices)
+
+ @property
+ def parameter_devices(self):
+ return list(self._devices)
+
+ @property
+ def experimental_between_graph(self):
+ return False
+
+ @property
+ def experimental_should_init(self):
+ return True
+
+ @property
+ def should_checkpoint(self):
+ return True
+
+ @property
+ def should_save_summary(self):
+ return True
+
+ def non_slot_devices(self, var_list):
+ del var_list
+ return list(self._devices)
+
+ def _get_devices_from(self, colocate_with=None):
+ if colocate_with is None:
+ return self._devices
+ else:
+ return cross_device_ops_lib.get_devices_from(colocate_with)
+
+ # TODO(priyag): Delete this once all strategies use global batch size.
+ @property
+ def _global_batch_size(self):
+ return True
+
+ class _MirroredReplicaThread(threading.Thread):
+ """A thread that runs() a function on a device."""
+
+ def __init__(self, dist, coord, device, variable_creator_fn, fn, *args,
+ **kwargs):
+ super(MirroredExtended._MirroredReplicaThread, self).__init__() # pylint: disable=protected-access
+ self.coord = coord
+ self.distribution = dist
+ self.device = device
+ self.replica_id = dist.extended.worker_devices.index(device)
+ self.variable_creator_fn = variable_creator_fn
+ # State needed to run and return the results of `fn`.
+ self.main_fn = fn
+ self.main_args = args
+ self.main_kwargs = kwargs
+ self.main_result = None
+ self.done = False
+ # State needed to run the next merge_call() (if any) requested via
+ # ReplicaContext.
+ self.merge_fn = None
+ self.merge_args = None
+ self.merge_kwargs = None
+ self.merge_result = None
+ self.captured_name_scope = None
+ # We use a thread.Event for the main thread to signal when this
+ # thread should start running (`should_run`), and another for
+ # this thread to transfer control back to the main thread
+ # (`has_paused`, either when it gets to a
+ # `get_replica_context().merge_call` or when `fn` returns). In
+ # either case the event starts cleared, is signaled by calling
+ # set(). The receiving thread waits for the signal by calling
+ # wait() and then immediately clearing the event using clear().
+ self.should_run = threading.Event()
+ self.has_paused = threading.Event()
+ # These fields have to do with inheriting various contexts from the
+ # parent thread:
+ # pylint: disable=protected-access
+ self.context_mode = context.context()._eager_context.mode
+ if not context.context()._context_handle:
+ context.context()._initialize_handle_and_devices()
+ self.context_device_policy = (
+ pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
+ context.context()._context_handle))
+ self.graph = ops.get_default_graph()
+ self._variable_creator_stack = self.graph._variable_creator_stack[:]
+ self._captured_var_scope = variable_scope.get_variable_scope()
+ # Adding a "/" at end lets us re-enter this scope later.
+ self._name_scope = self.graph.get_name_scope()
+ if self._name_scope:
+ self._name_scope += "/"
+ if self.replica_id > 0:
+ if not self._name_scope:
+ self._name_scope = ""
+ self._name_scope += "replica_%d/" % self.replica_id
+
+ def run(self):
+ # pylint: disable=protected-access
+ self.graph._variable_creator_stack = self._variable_creator_stack
+ self.should_run.wait()
+ self.should_run.clear()
+ try:
+ if self.coord.should_stop():
+ return
+ with self.coord.stop_on_exception(), \
+ context.context()._mode(self.context_mode), \
+ context.context().device_policy(self.context_device_policy), \
+ _enter_graph(self.graph), \
+ MirroredReplicaContext(self.distribution, constant_op.constant(
+ self.replica_id, dtypes.int32)), \
+ ops.device(self.device), \
+ ops.name_scope(self._name_scope), \
+ variable_scope.variable_scope(
+ self._captured_var_scope, reuse=self.replica_id > 0), \
+ variable_scope.variable_creator_scope(self.variable_creator_fn):
+ self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
+ self.done = True
+ finally:
+ self.has_paused.set()
+
+
+class MirroredReplicaContext(distribute_lib.ReplicaContext):
+ """ReplicaContext used in MirroredStrategy.call_for_each_replica().
+
+ Opened in `_MirroredReplicaThread`, to allow the user to invoke
+ `MirroredStrategy`'s specific implementation of `merge_call()`,
+ which works by delegating the function and its arguments to
+ the main thread (the one that invoked
+ `MirroredStrategy.call_for_each_replica()`).
+ """
+
+ def _merge_call(self, fn, args, kwargs):
+ """Delegate to the main thread to actually perform merge_call()."""
+ t = threading.current_thread() # a _MirroredReplicaThread
+ t.merge_fn = fn
+ t.merge_args = args
+ t.merge_kwargs = kwargs
+ t.captured_name_scope = t.graph.get_name_scope()
+ # Adding a "/" at end lets us re-enter this scope later.
+ if t.captured_name_scope:
+ t.captured_name_scope += "/"
+ t.has_paused.set()
+ t.should_run.wait()
+ t.should_run.clear()
+ if t.coord.should_stop():
+ raise _RequestedStop()
+ return t.merge_result
+
+ @property
+ def devices(self):
+ distribute_lib.require_replica_context(self)
+ replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
+ return [self._distribution_strategy.extended.worker_devices[replica_id]]
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 33ca27c..7dd1062 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -30,6 +30,9 @@
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
@@ -42,9 +45,6 @@
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.training import device_util
-from tensorflow.python.training import distribute as distribute_lib
-from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import saver
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
@@ -1433,7 +1433,7 @@
# TODO(priyag): We should probably explicitly specify CPU device on worker.
with ops.device(worker):
result = input_fn(ctx)
- if not isinstance(result, dataset_ops.Dataset):
+ if not isinstance(result, dataset_ops.DatasetV2):
raise ValueError("input_fn must return a tf.data.Dataset.")
iterator = _SingleWorkerDatasetIterator(result, worker, devices)
iterators.append(iterator)
@@ -1608,11 +1608,11 @@
"""A context object that can be used to capture things when running steps.
This context object is useful when running multiple steps at a time using the
- `run_steps_on_dataset` API. For e.g. it allows the user's step function to
- specify which outputs to emit at what frequency. Currently it supports
- capturing output from the last step, as well as capturing non tensor outputs.
- In the future it will be augmented to support other use cases such as output
- each N steps.
+ `experimental_run_steps_on_iterator` API. For e.g. it allows the user's step
+ function to specify which outputs to emit at what frequency. Currently it
+ supports capturing output from the last step, as well as capturing non tensor
+ outputs. In the future it will be augmented to support other use cases such
+ as output each N steps.
"""
def __init__(self):
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 55728b1..5a18afa 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -189,9 +189,6 @@
"//tensorflow/python:resource_variable_ops",
],
shard_count = 5,
- tags = [
- "no_windows",
- ],
)
cuda_py_test(
@@ -214,9 +211,6 @@
"//tensorflow/python:resource_variable_ops",
],
shard_count = 15,
- tags = [
- "no_windows",
- ],
)
py_library(
@@ -341,6 +335,7 @@
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:execute",
"//tensorflow/python/eager:tape",
+ "//tensorflow/python/ops/parallel_for:control_flow_ops",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 5b6b421..69d444a 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -20,6 +20,7 @@
import functools
import operator
+import sys
import six
@@ -33,6 +34,7 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
@@ -42,9 +44,20 @@
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
+from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
+# Note that we need to lazy load the following two modules to avoid creating
+# circular dependencies.
+# TODO(b/119775953): fix the circular dependencies.
+pfor_ops = LazyLoader(
+ "pfor_ops", globals(),
+ "tensorflow.python.ops.parallel_for.control_flow_ops")
+
+function = LazyLoader("function", globals(),
+ "tensorflow.python.eager.function")
+
_op_attr_type_cache = {}
@@ -536,11 +549,11 @@
if len(gradients) == 1:
return gradients[0]
- if all([isinstance(g, ops.Tensor) for g in gradients]):
+ if all(isinstance(g, ops.Tensor) for g in gradients):
return gen_math_ops.add_n(gradients)
else:
- assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
- for g in gradients])
+ assert all(isinstance(g, (ops.Tensor, ops.IndexedSlices))
+ for g in gradients)
indexed_slices_list = []
for grad in gradients:
# TODO(xpan): Support nested IndexedSlices and core IndexedSlices
@@ -937,3 +950,203 @@
grad = nest.pack_sequence_as(sources, flat_grad)
return grad
+
+ def jacobian(self,
+ target,
+ sources,
+ unconnected_gradients=UnconnectedGradients.NONE,
+ experimental_use_pfor=True):
+ """Computes the jacobian using operations recorded in context of this tape.
+
+ See http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant for the
+ definition of a Jacobian.
+
+ Example usage:
+
+ with tf.GradientTape() as g:
+ x = tf.constant([1.0, 2.0])
+ g.watch(x)
+ y = x * x
+ jacobian = g.jacobian(y, x)
+ # jacobian value is [[2., 0.], [0., 4.]]
+
+ Args:
+ target: Tensor to be differentiated.
+ sources: a list or nested structure of Tensors or Variables. `target`
+ will be differentiated against elements in `sources`.
+ unconnected_gradients: a value which can either hold 'none' or 'zero' and
+ alters the value which will be returned if the target and sources are
+ unconnected. The possible values and effects are detailed in
+ 'UnconnectedGradients' and it defaults to 'none'.
+ experimental_use_pfor: If true, vectorizes the jacobian computation. Else
+ falls back to a sequential while_loop. Vectorization can sometimes fail
+ or lead to excessive memory usage. This option can be used to disable
+ vectorization in such cases.
+
+ Returns:
+ a list or nested structure of Tensors (or IndexedSlices, or None),
+ one for each element in `sources`. Returned structure is the same as
+ the structure of `sources`.
+
+ Raises:
+ RuntimeError: If called on a non-persistent tape with eager execution
+ enabled and without enabling experimental_use_pfor.
+ ValueError: If vectorization of jacobian computation fails.
+ """
+ flat_sources = nest.flatten(sources)
+ target_static_shape = target.shape
+ target_shape = array_ops.shape(target)
+ # Note that we push and pop the tape here and below. This is needed since we
+ # need gradients through the enclosed operations.
+ self._push_tape()
+ target = array_ops.reshape(target, [-1])
+ self._pop_tape()
+
+ def loop_fn(i):
+ self._push_tape()
+ y = array_ops.gather(target, i)
+ self._pop_tape()
+ return self.gradient(y, flat_sources,
+ unconnected_gradients=unconnected_gradients)
+
+ try:
+ target_size = int(target.shape[0])
+ except TypeError:
+ target_size = array_ops.shape(target)[0]
+
+ if experimental_use_pfor:
+ try:
+ output = pfor_ops.pfor(loop_fn, target_size)
+ except ValueError as err:
+ six.reraise(
+ ValueError,
+ ValueError(
+ str(err) + "\nEncountered an exception while vectorizing the "
+ "jacobian computation. Vectorization can be disabled by setting"
+ " experimental_use_pfor to False."),
+ sys.exc_info()[2])
+ else:
+ if context.executing_eagerly() and not self._persistent:
+ raise RuntimeError(
+ "GradientTape must be created with persistent=True"
+ " to compute the jacobian with eager execution enabled and with "
+ " experimental_use_pfor set to False.")
+ output = pfor_ops.for_loop(
+ loop_fn, [target.dtype] * len(flat_sources), target_size)
+
+ for i, out in enumerate(output):
+ if out is not None:
+ new_shape = array_ops.concat(
+ [target_shape, array_ops.shape(out)[1:]], axis=0)
+ out = array_ops.reshape(out, new_shape)
+ if context.executing_eagerly():
+ out.set_shape(target_static_shape.concatenate(flat_sources[i].shape))
+ output[i] = out
+
+ return nest.pack_sequence_as(sources, output)
+
+ def batch_jacobian(self,
+ target,
+ source,
+ unconnected_gradients=UnconnectedGradients.NONE,
+ experimental_use_pfor=True):
+ """Computes and stacks per-example jacobians.
+
+ See http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant for the
+ definition of a Jacobian. This function is essentially an efficient
+ implementation of the following:
+ `tf.stack([self.jacobian(y[i], x[i]) for i in range(x.shape[0])])`.
+
+ Note that compared to `GradientTape.jacobian` which computes gradient of
+ each output value w.r.t each input value, this function is useful when
+ `target[i,...] is independent of `source[j,...]` for `j != i`. This
+ independence assumption allows more efficient computation as compared to
+ `GradientTape.jacobian`. The output, as well as intermediate activations,
+ are lower dimensional and avoid a bunch of redundant zeros which would
+ result in the jacobian computation given the independence assumption.
+
+ Example usage:
+ with tf.GradientTape() as g:
+ x = tf.constant([[1, 2], [3, 4]], dtype=tf.float32)
+ g.watch(x)
+ y = x * x
+ batch_jacobian = g.batch_jacobian(y, x)
+ # batch_jacobian is [[[2, 0], [0, 4]], [[6, 0], [0, 8]]]
+
+ Args:
+ target: A tensor with rank 2 or higher and with shape [b, y1, ..., y_n].
+ `target[i,...]` should only depend on `source[i,...]`.
+ source: A tensor with rank 2 or higher and with shape [b, x1, ..., x_m].
+ unconnected_gradients: a value which can either hold 'none' or 'zero' and
+ alters the value which will be returned if the target and sources are
+ unconnected. The possible values and effects are detailed in
+ 'UnconnectedGradients' and it defaults to 'none'.
+ experimental_use_pfor: If true, uses pfor for computing the Jacobian. Else
+ uses a tf.while_loop.
+
+ Returns:
+ A tensor `t` with shape [b, y_1, ..., y_n, x1, ..., x_m] where `t[i, ...]`
+ is the jacobian of `target[i, ...]` w.r.t. `source[i, ...]`, i.e. stacked
+ per-example jacobians.
+
+ Raises:
+ RuntimeError: If called on a non-persistent tape with eager execution
+ enabled and without enabling experimental_use_pfor.
+ ValueError: If vectorization of jacobian computation fails or if first
+ dimension of `target` and `source` do not match.
+ """
+ target_shape = target.shape
+ if not target_shape.with_rank_at_least(2)[0].is_compatible_with(
+ source.shape.with_rank_at_least(2)[0]):
+ raise ValueError(
+ "Need first dimension of target shape (%s) and "
+ "source shape (%s) to match." % (target.shape, source.shape))
+ if target_shape.is_fully_defined():
+ batch_size = int(target_shape[0])
+ target_row_size = target_shape.num_elements() // batch_size
+ else:
+ target_shape = array_ops.shape(target)
+ batch_size = target_shape[0]
+ target_row_size = array_ops.size(target) // batch_size
+ source_shape = array_ops.shape(source)
+ # Flatten target to 2-D.
+ # Note that we push and pop the tape here and below. This is needed since we
+ # need gradients through the enclosed operations.
+ self._push_tape()
+ with ops.control_dependencies(
+ [check_ops.assert_equal(batch_size, source_shape[0])]):
+ target = array_ops.reshape(target, [batch_size, target_row_size])
+ self._pop_tape()
+
+ def loop_fn(i):
+ self._push_tape()
+ y = array_ops.gather(target, i, axis=1)
+ self._pop_tape()
+ return self.gradient(y, source,
+ unconnected_gradients=unconnected_gradients)
+
+ if experimental_use_pfor:
+ try:
+ output = pfor_ops.pfor(loop_fn, target_row_size)
+ except ValueError as err:
+ six.reraise(
+ ValueError,
+ ValueError(
+ str(err) + "\nEncountered an exception while vectorizing the "
+ "batch_jacobian computation. Vectorization can be disabled by "
+ "setting experimental_use_pfor to False."),
+ sys.exc_info()[2])
+ else:
+ if context.executing_eagerly() and not self._persistent:
+ raise RuntimeError(
+ "GradientTape must be created with persistent=True"
+ " to compute the batch_jacobian with eager execution enabled and "
+ " with experimental_use_pfor set to False.")
+ output = pfor_ops.for_loop(loop_fn, target.dtype, target_row_size)
+ if output is None:
+ return None
+ output = array_ops.reshape(output,
+ [target_row_size, batch_size, -1])
+ output = array_ops.transpose(output, [1, 0, 2])
+ new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0)
+ return array_ops.reshape(output, new_shape)
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index d9f2a95..08553b9 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -74,7 +74,7 @@
tf_g1 = embedding_ops.embedding_lookup(tf_var, tf_ind1)
tf_g2 = embedding_ops.embedding_lookup(tf_var, tf_ind2)
tf_g3 = embedding_ops.embedding_lookup(tf_var, tf_ind3)
- tf_g4 = math_ops.reduce_sum(tf_var * 2.0, reduction_indices=(0, 1))
+ tf_g4 = math_ops.reduce_sum(tf_var * 2.0, axis=(0, 1))
tf_y = tf_g1 * tf_g2 * tf_g3 * tf_g4
tf_grad = gradients.gradients(tf_y, [tf_var])[0]
@@ -1227,5 +1227,175 @@
self.assertAllEqual(da[0], tf_da[0].eval())
+@test_util.run_all_in_graph_and_eager_modes
+class JacobianTest(test.TestCase):
+
+ def _jacobian(self, experimental_use_pfor):
+ persistent = context.executing_eagerly and not experimental_use_pfor
+ with backprop.GradientTape(persistent=persistent) as g:
+ x = constant_op.constant([1., 2.])
+ y = constant_op.constant([3., 4.])
+ g.watch(x)
+ g.watch(y)
+ z = x * x * y
+ jacobian = g.jacobian(z, [x, y],
+ experimental_use_pfor=experimental_use_pfor)
+ answer = [array_ops.diag(2 * x * y), array_ops.diag(x * x)]
+ return jacobian, answer
+
+ def testPfor(self):
+ jacobian, answer = self._jacobian(experimental_use_pfor=True)
+ for j, a in zip(jacobian, answer):
+ self.assertAllEqual(a, j)
+
+ def testWhileLoop(self):
+ jacobian, answer = self._jacobian(experimental_use_pfor=False)
+ for j, a in zip(jacobian, answer):
+ self.assertAllEqual(a, j)
+
+ def testPforDefun(self):
+
+ @function.defun
+ def _f():
+ return self._jacobian(experimental_use_pfor=True)
+
+ jacobian, answer = _f()
+ for j, a in zip(jacobian, answer):
+ self.assertAllEqual(a, j)
+
+ def testWhileLoopDefun(self):
+
+ @function.defun
+ def _f():
+ return self._jacobian(experimental_use_pfor=False)
+
+ jacobian, answer = _f()
+ for j, a in zip(jacobian, answer):
+ self.assertAllEqual(a, j)
+
+ def testPersistentTape(self):
+ if not context.executing_eagerly():
+ return
+ with backprop.GradientTape() as g:
+ x = constant_op.constant([1.0, 2.0])
+ g.watch(x)
+ y = x * x
+ with self.assertRaisesRegexp(RuntimeError, 'persistent'):
+ g.jacobian(y, x, experimental_use_pfor=False)
+
+ def testPforException(self):
+ var = variables.Variable([1.])
+
+ @custom_gradient.custom_gradient
+ def op(x):
+ def grad(_):
+ # Note that we perform a stateful operation here that will not be
+ # compatible with parallel for construct.
+ with ops.control_dependencies(
+ [var.assign(random_ops.random_uniform([1]))]):
+ return constant_op.constant(1.)
+ return x, grad
+
+ with backprop.GradientTape() as g:
+ x = constant_op.constant([1., 2.])
+ g.watch(x)
+ y = op(x)
+ with self.assertRaisesRegexp(ValueError, 'No converter'):
+ g.jacobian(y, x, experimental_use_pfor=True)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class BatchJacobianTest(test.TestCase):
+
+ def _batch_jacobian(self, experimental_use_pfor):
+ persistent = context.executing_eagerly and not experimental_use_pfor
+ with backprop.GradientTape(persistent=persistent) as g:
+ x = constant_op.constant([[1., 2.], [3., 4.]])
+ y = constant_op.constant([[3., 4.], [5., 6.]])
+ g.watch(x)
+ z = x * x * y
+ batch_jacobian = g.batch_jacobian(
+ z, x, experimental_use_pfor=experimental_use_pfor)
+ answer = array_ops.stack([array_ops.diag(2 * x[0] * y[0]),
+ array_ops.diag(2 * x[1] * y[1])])
+ return batch_jacobian, answer
+
+ def testPfor(self):
+ batch_jacobian, answer = self._batch_jacobian(experimental_use_pfor=True)
+ self.assertAllEqual(answer, batch_jacobian)
+
+ def testWhileLoop(self):
+ batch_jacobian, answer = self._batch_jacobian(experimental_use_pfor=False)
+ self.assertAllEqual(answer, batch_jacobian)
+
+ def testPforDefun(self):
+
+ @function.defun
+ def _f():
+ return self._batch_jacobian(experimental_use_pfor=True)
+
+ batch_jacobian, answer = _f()
+ self.assertAllEqual(answer, batch_jacobian)
+
+ def testWhileLoopDefun(self):
+
+ @function.defun
+ def _f():
+ return self._batch_jacobian(experimental_use_pfor=False)
+
+ batch_jacobian, answer = _f()
+ self.assertAllEqual(answer, batch_jacobian)
+
+ def testPersistentTape(self):
+ if not context.executing_eagerly():
+ return
+ with backprop.GradientTape() as g:
+ x = constant_op.constant([[1.0, 2.0]])
+ g.watch(x)
+ y = x * x
+ with self.assertRaisesRegexp(RuntimeError, 'persistent'):
+ g.batch_jacobian(y, x, experimental_use_pfor=False)
+
+ def testBadShape(self):
+ x = random_ops.random_uniform([2, 3])
+ with backprop.GradientTape() as g:
+ y = array_ops.concat([x, x], axis=0)
+ with self.assertRaisesRegexp(ValueError, 'Need first dimension'):
+ g.batch_jacobian(y, x)
+
+ def testBadInputRank(self):
+ x = random_ops.random_uniform([2])
+ with backprop.GradientTape() as g:
+ y = random_ops.random_uniform([2, 2])
+ with self.assertRaisesRegexp(ValueError, 'must have rank at least 2'):
+ g.batch_jacobian(y, x)
+
+ def testBadOutputRank(self):
+ x = random_ops.random_uniform([2, 2])
+ with backprop.GradientTape() as g:
+ y = random_ops.random_uniform([2])
+ with self.assertRaisesRegexp(ValueError, 'must have rank at least 2'):
+ g.batch_jacobian(y, x)
+
+ def testPforException(self):
+ var = variables.Variable([1.])
+
+ @custom_gradient.custom_gradient
+ def op(x):
+ def grad(_):
+ # Note that we perform a stateful operation here that will not be
+ # compatible with parallel for construct.
+ with ops.control_dependencies(
+ [var.assign(random_ops.random_uniform([1]))]):
+ return constant_op.constant(1.)
+ return x, grad
+
+ with backprop.GradientTape() as g:
+ x = constant_op.constant([[1.], [2.]])
+ g.watch(x)
+ y = op(x)
+ with self.assertRaisesRegexp(ValueError, 'No converter'):
+ g.batch_jacobian(y, x, experimental_use_pfor=True)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 8867158..31a7efc 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -80,7 +80,6 @@
def __init__(self, initializer="ones"):
super(SubclassedKerasModel, self).__init__()
- self._can_use_graph_functions = True
self.layer_a = keras.layers.Dense(
64, kernel_initializer=initializer, bias_initializer="zeros")
self.layer_b = keras.layers.Dense(
@@ -733,38 +732,38 @@
assert np.equal(func(), make_keras_model()(data)).all()
self._run(func, 30000)
- def _benchmark_keras_model_fit(self, model):
+ def _benchmark_keras_model_fit(self, model, run_eagerly=False):
data = random_ops.random_uniform((10, 10), minval=-1, maxval=1)
labels = random_ops.random_uniform((10, 10), minval=-1, maxval=1)
dataset = dataset_ops.Dataset.from_tensors((data, labels)).repeat()
model.compile(
gradient_descent.GradientDescentOptimizer(learning_rate=0.001),
- loss="mse")
+ loss="mse", run_eagerly=run_eagerly)
func = lambda: model.fit(dataset, epochs=1, steps_per_epoch=1000, verbose=0)
# First call is more expensive (creates variables etc.), discount that.
model.fit(dataset, epochs=1, steps_per_epoch=1, verbose=0)
self._run(func, 1)
- def _benchmark_keras_model_evaluate(self, model):
+ def _benchmark_keras_model_evaluate(self, model, run_eagerly=False):
data = random_ops.random_uniform((10, 10), minval=-1, maxval=1)
labels = random_ops.random_uniform((10, 10), minval=-1, maxval=1)
dataset = dataset_ops.Dataset.from_tensors((data, labels)).repeat()
model.compile(
gradient_descent.GradientDescentOptimizer(learning_rate=0.001),
- loss="mse")
+ loss="mse", run_eagerly=run_eagerly)
func = lambda: model.evaluate(dataset, steps=1000, verbose=0)
# First call is more expensive (creates variables etc.), discount that.
model.evaluate(dataset, steps=1, verbose=0)
self._run(func, 1)
- def _benchmark_keras_model_predict(self, model):
+ def _benchmark_keras_model_predict(self, model, run_eagerly=False):
data = random_ops.random_uniform((10, 10), minval=-1, maxval=1)
dataset = dataset_ops.Dataset.from_tensors(tuple([data])).repeat()
model.compile(
gradient_descent.GradientDescentOptimizer(learning_rate=0.001),
- loss="mse")
+ loss="mse", run_eagerly=run_eagerly)
func = lambda: model.predict(dataset, steps=1000, verbose=0)
# First call is more expensive (creates variables etc.), discount that.
model.predict(dataset, steps=1, verbose=0)
@@ -780,10 +779,9 @@
model = SubclassedKerasModel(initializer="glorot_uniform")
self._benchmark_keras_model_fit(model)
- def benchmark_keras_model_subclassed_fit_disable_defun(self):
+ def benchmark_keras_model_subclassed_fit_run_model_eagerly(self):
model = SubclassedKerasModel(initializer="glorot_uniform")
- model._can_use_graph_functions = False
- self._benchmark_keras_model_fit(model)
+ self._benchmark_keras_model_fit(model, run_eagerly=True)
def benchmark_keras_model_functional_fit(self):
model = make_keras_model(initializer="glorot_uniform")
@@ -794,10 +792,9 @@
model = make_keras_model(initializer="glorot_uniform")
self._benchmark_keras_model_fit(model)
- def benchmark_keras_model_functional_fit_disable_defun(self):
+ def benchmark_keras_model_functional_fit_run_model_eagerly(self):
model = make_keras_model(initializer="glorot_uniform")
- model._can_use_graph_functions = False
- self._benchmark_keras_model_fit(model)
+ self._benchmark_keras_model_fit(model, run_eagerly=True)
def benchmark_keras_model_sequential_fit(self):
model = make_sequential_keras_model(initializer="glorot_uniform")
@@ -808,64 +805,57 @@
model = make_sequential_keras_model(initializer="glorot_uniform")
self._benchmark_keras_model_fit(model)
- def benchmark_keras_model_sequential_fit_disable_defun(self):
+ def benchmark_keras_model_sequential_fit_run_model_eagerly(self):
model = make_sequential_keras_model(initializer="glorot_uniform")
- model._can_use_graph_functions = False
- self._benchmark_keras_model_fit(model)
+ self._benchmark_keras_model_fit(model, run_eagerly=True)
def benchmark_keras_model_subclassed_evaluate(self):
model = SubclassedKerasModel(initializer="glorot_uniform")
self._benchmark_keras_model_evaluate(model)
- def benchmark_keras_model_subclassed_evaluate_disable_defun(self):
+ def benchmark_keras_model_subclassed_evaluate_run_model_eagerly(self):
model = SubclassedKerasModel(initializer="glorot_uniform")
- model._can_use_graph_functions = False
- self._benchmark_keras_model_evaluate(model)
+ self._benchmark_keras_model_evaluate(model, run_eagerly=True)
def benchmark_keras_model_functional_evaluate(self):
model = make_keras_model(initializer="glorot_uniform")
self._benchmark_keras_model_evaluate(model)
- def benchmark_keras_model_functional_evaluate_disable_defun(self):
+ def benchmark_keras_model_functional_evaluate_run_model_eagerly(self):
model = make_keras_model(initializer="glorot_uniform")
- model._can_use_graph_functions = False
- self._benchmark_keras_model_evaluate(model)
+ self._benchmark_keras_model_evaluate(model, run_eagerly=True)
def benchmark_keras_model_sequential_evaluate(self):
model = make_sequential_keras_model(initializer="glorot_uniform")
self._benchmark_keras_model_evaluate(model)
- def benchmark_keras_model_sequential_evaluate_disable_defun(self):
+ def benchmark_keras_model_sequential_evaluate_run_model_eagerly(self):
model = make_sequential_keras_model(initializer="glorot_uniform")
- model._can_use_graph_functions = False
- self._benchmark_keras_model_evaluate(model)
+ self._benchmark_keras_model_evaluate(model, run_eagerly=True)
def benchmark_keras_model_subclassed_predict(self):
model = SubclassedKerasModel(initializer="glorot_uniform")
self._benchmark_keras_model_predict(model)
- def benchmark_keras_model_subclassed_predict_disable_defun(self):
+ def benchmark_keras_model_subclassed_predict_run_model_eagerly(self):
model = SubclassedKerasModel(initializer="glorot_uniform")
- model._can_use_graph_functions = False
- self._benchmark_keras_model_predict(model)
+ self._benchmark_keras_model_predict(model, run_eagerly=True)
def benchmark_keras_model_functional_predict(self):
model = make_keras_model(initializer="glorot_uniform")
self._benchmark_keras_model_predict(model)
- def benchmark_keras_model_functional_predict_disable_defun(self):
+ def benchmark_keras_model_functional_predict_run_model_eagerly(self):
model = make_keras_model(initializer="glorot_uniform")
- model._can_use_graph_functions = False
- self._benchmark_keras_model_predict(model)
+ self._benchmark_keras_model_predict(model, run_eagerly=True)
def benchmark_keras_model_sequential_predict(self):
model = make_sequential_keras_model(initializer="glorot_uniform")
self._benchmark_keras_model_predict(model)
- def benchmark_keras_model_sequential_predict_disable_defun(self):
+ def benchmark_keras_model_sequential_predict_run_model_eagerly(self):
model = make_sequential_keras_model(initializer="glorot_uniform")
- model._can_use_graph_functions = False
- self._benchmark_keras_model_predict(model)
+ self._benchmark_keras_model_predict(model, run_eagerly=True)
def benchmarkScan(self):
elems = math_ops.range(1600)
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 52830d4..6bacd7a 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -552,9 +552,9 @@
return x + tf.to_float(c)
assert int(c) == 0
- assert f(1.0) == 3.0
+ assert f(1.0) == 2.0
assert int(c) == 1
- assert f(1.0) == 4.0
+ assert f(1.0) == 3.0
assert int(c) == 2
```
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index e863cf5..68cdb1a 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -48,6 +48,7 @@
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@@ -66,6 +67,11 @@
BACKWARD_FUNCTION_ATTRIBUTE_NAME
]
+CacheKey = collections.namedtuple("CacheKey", [
+ "input_signature", "parent_graph", "device_functions", "colocation_stack",
+ "uses_xla"
+])
+
def _parse_func_attrs(attributes):
"""Convert the keyword arguments into function_def attributes.
@@ -83,8 +89,8 @@
"""
attrs = {}
for key, value in attributes.items():
- if not any([re.match(reg, key)
- for reg in WHITELIST_FUNCTION_ATTRIBUTE_REGEX]):
+ if not any(re.match(reg, key)
+ for reg in WHITELIST_FUNCTION_ATTRIBUTE_REGEX):
raise ValueError("Attribute name is not whitelisted. "
"Whitelisted: prefix %s, got: %s" %
(WHITELIST_FUNCTION_ATTRIBUTE_REGEX, key))
@@ -418,7 +424,10 @@
if (tape.should_record(tensor_inputs) or
tape.should_record(self._captured_inputs)):
- return self._backprop_call(args)
+ if context.executing_eagerly():
+ return self._eager_backprop_call(args)
+ else:
+ return self._backprop_call_with_delayed_rewrite(args)
# Only need to override the gradient in graph mode and when we have outputs.
if context.executing_eagerly() or not self.outputs:
@@ -444,37 +453,40 @@
name: The name to register the gradient as.
"""
@ops.RegisterGradient(name)
- def grad_fn(op, *doutputs): # pylint: disable=unused-variable
- """Gradients of this function."""
- if self._backward_graph_function is None:
- self._construct_backprop_function()
+ def _registered_grad_fn(op, *doutputs): # pylint: disable=unused-variable
+ return self._grad_fn(op, *doutputs)
- # pylint: disable=protected-access
- self._forward_function.add_to_graph(op.graph)
- num_inference_outputs = self._inference_function._num_outputs
+ def _grad_fn(self, op, *doutputs):
+ """Gradients of this function."""
+ if self._backward_graph_function is None:
+ self._construct_backprop_function()
- # Rewrite an inference call op to be a forward call op
- if op.get_attr("f").name.encode() == self._inference_function.name:
- func = attr_value_pb2.AttrValue(
- func=attr_value_pb2.NameAttrList(
- name=self._forward_function.name))
- op._set_attr("f", func)
- types = attr_value_pb2.AttrValue.ListValue(
- type=self._forward_function._output_types)
- op._set_attr("Tout", attr_value_pb2.AttrValue(list=types))
- for i in range(
- num_inference_outputs, len(self._forward_function._output_types)):
- t = ops.Tensor(op, i, self._forward_function._output_types[i])
- t.set_shape(self._forward_function._output_shapes[i])
- func_graph_output = self._forward_function._func_graph_outputs[i]
- custom_gradient.copy_handle_data(func_graph_output, t)
- op._outputs.append(t)
- # pylint: enable=protected-access
- # Compute the gradients using the side outputs
- side_outputs = op.outputs[num_inference_outputs:]
- args = list(doutputs[:num_inference_outputs]) + list(side_outputs)
- return self._backward_graph_function._call_flat( # pylint: disable=protected-access
- (a for a in args if a is not None))
+ # pylint: disable=protected-access
+ self._forward_function.add_to_graph(op.graph)
+ num_inference_outputs = self._inference_function._num_outputs
+
+ # Rewrite an inference call op to be a forward call op
+ if op.get_attr("f").name.encode() == self._inference_function.name:
+ func = attr_value_pb2.AttrValue(
+ func=attr_value_pb2.NameAttrList(
+ name=self._forward_function.name))
+ op._set_attr("f", func)
+ types = attr_value_pb2.AttrValue.ListValue(
+ type=self._forward_function._output_types)
+ op._set_attr("Tout", attr_value_pb2.AttrValue(list=types))
+ for i in range(
+ num_inference_outputs, len(self._forward_function._output_types)):
+ t = ops.Tensor(op, i, self._forward_function._output_types[i])
+ t.set_shape(self._forward_function._output_shapes[i])
+ func_graph_output = self._forward_function._func_graph_outputs[i]
+ custom_gradient.copy_handle_data(func_graph_output, t)
+ op._outputs.append(t)
+ # pylint: enable=protected-access
+ # Compute the gradients using the side outputs
+ side_outputs = op.outputs[num_inference_outputs:]
+ args = list(doutputs[:num_inference_outputs]) + list(side_outputs)
+ return self._backward_graph_function._call_flat( # pylint: disable=protected-access
+ (a for a in args if a is not None))
@property
def name(self):
@@ -617,10 +629,13 @@
self._func_graph.outputs + backwards_graph_captures,
forward_function_attr)
- def _backprop_call(self, args):
+ def _eager_backprop_call(self, args):
"""Calls the forward function and records the result on a tape.
- (Only records results on a tape if the function has outputs)
+ This method fully constructs the forward and backward functions before
+ calling the function and recording them on the tape.
+
+ (Only records results on a tape if the function has outputs).
Args:
args: All inputs to the function, including resolved captured inputs
@@ -662,6 +677,46 @@
args, backward_function)
return self._build_call_outputs(real_outputs)
+ def _backprop_call_with_delayed_rewrite(self, args):
+ """Calls the inference function and records the result on a tape.
+
+ The recorded backwards function will construct the backwards graph and
+ rewrite the inference function to the forward function. This only happens
+ if the recorded backwards function ends up being used to compute gradients.
+
+ This approach avoids constructing unnecessary graphs, but it only works if
+ we are calling this function when not executing eagerly.
+
+ (Only records results on a tape if the function has outputs)
+
+ Args:
+ args: All inputs to the function, including resolved captured inputs
+
+ Returns:
+ The call output.
+ """
+ ctx = context.context()
+
+ if not self._gradient_name:
+ self._gradient_name = "PartitionedCall-%s" % ops.uid()
+ self._register_gradient(self._gradient_name)
+ with ops.get_default_graph().gradient_override_map(
+ {"PartitionedCall": self._gradient_name,
+ "StatefulPartitionedCall": self._gradient_name}):
+ outputs = self._inference_function.call(ctx, args)
+
+ if isinstance(outputs, ops.Operation) or outputs is None:
+ return outputs
+
+ call_op = outputs[0].op
+
+ def backward_function(*args):
+ return self._grad_fn(call_op, *args)
+
+ tape.record_operation(self._inference_function.signature.name, outputs,
+ args, backward_function)
+ return self._build_call_outputs(outputs)
+
def _build_call_outputs(self, result):
"""Maps the fdef output list to actual output structure.
@@ -927,17 +982,17 @@
"""Computes the cache key given inputs and execution context."""
if self._input_signature is None:
inputs = (args, kwargs) if kwargs else args
- cache_key = pywrap_tensorflow.TFE_Py_EncodeArg(inputs)
+ input_signature = pywrap_tensorflow.TFE_Py_EncodeArg(inputs)
else:
del args, kwargs
- cache_key = self._flat_input_signature
+ input_signature = self._flat_input_signature
ctx = context.context()
with ops.init_scope():
# The graph, or whether we're executing eagerly, should be a part of the
# cache key so we don't improperly capture tensors such as variables.
executing_eagerly = ctx.executing_eagerly()
- execution_context = executing_eagerly or ops.get_default_graph()
+ parent_graph = None if executing_eagerly else ops.get_default_graph()
# pylint: disable=protected-access
default_graph = ops.get_default_graph()
@@ -966,8 +1021,8 @@
else:
device_functions = ()
# pylint: enable=protected-access
- return (cache_key, execution_context, device_functions, colocation_stack,
- uses_xla)
+ return CacheKey(input_signature, parent_graph, device_functions,
+ colocation_stack, uses_xla)
def _canonicalize_function_inputs(self, *args, **kwargs):
"""Canonicalizes `args` and `kwargs`.
@@ -1083,6 +1138,9 @@
"must be hashable.")
if graph_function is None:
+ logging.vlog(1,
+ "Creating new FuncGraph for Python function %r (key: %r)",
+ self._python_function, cache_key)
if self._input_signature is None:
arglen = len(args)
else:
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index 55f0896..0d0f70d 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -439,8 +439,8 @@
PyErr_SetString(
PyExc_TypeError,
tensorflow::strings::StrCat(
- "Cannot convert value ", TFE_GetPythonString(value_str.get()),
- " to EagerTensor with requested dtype: ",
+ "Cannot convert provided value to EagerTensor. Provided value: ",
+ TFE_GetPythonString(value_str.get()), " Requested dtype: ",
tensorflow::DataTypeString(
static_cast<tensorflow::DataType>(desired_dtype)))
.c_str());
@@ -672,11 +672,29 @@
#endif
}
+// Getter `backing_device`.
+static PyObject* EagerTensor_backing_device(EagerTensor* self) {
+ const char* device =
+ TFE_TensorHandleBackingDeviceName(self->handle, self->status);
+ if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
+ // Cleanup self->status before returning.
+ TF_SetStatus(self->status, TF_OK, "");
+ return nullptr;
+ }
+#if PY_MAJOR_VERSION >= 3
+ return PyUnicode_FromString(device);
+#else
+ return PyBytes_FromString(device);
+#endif
+}
+
static PyGetSetDef EagerTensor_getseters[] = {
{const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
const_cast<char*>("_id"), nullptr},
{const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
const_cast<char*>("device"), nullptr},
+ {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device,
+ nullptr, const_cast<char*>("backing_device"), nullptr},
{const_cast<char*>("_handle_data"), (getter)EagerTensor_tensor_handle,
(setter)EagerTensor_settensor_handle, const_cast<char*>("_tensor_handle"),
nullptr},
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index f074b73..9ce500b 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -2303,8 +2303,10 @@
PyErr_SetString(
PyExc_TypeError,
tensorflow::strings::StrCat(
- "Cannot convert value ", TFE_GetPythonString(input_str.get()),
- " to EagerTensor with requested dtype: ", desired_dtype)
+ "Cannot convert provided value to EagerTensor. Provided value: ",
+ TFE_GetPythonString(input_str.get()), " Requested dtype: ",
+ tensorflow::DataTypeString(
+ static_cast<tensorflow::DataType>(desired_dtype)))
.c_str());
return false;
}
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index 1326f09..e501b40 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -63,7 +63,7 @@
"""Marks this variable to be watched by the given tape."""
strategy = distribution_strategy_context.get_distribution_strategy()
if distribution_strategy_context.get_replica_context():
- variables = [strategy.value_container(variable)]
+ variables = [strategy.extended.value_container(variable)]
else:
variables = strategy.unwrap(variable)
for var in variables:
@@ -78,7 +78,7 @@
"""
strategy = distribution_strategy_context.get_distribution_strategy()
if distribution_strategy_context.get_replica_context():
- variables = [strategy.value_container(variable)]
+ variables = [strategy.extended.value_container(variable)]
else:
variables = strategy.unwrap(variable)
for var in variables:
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index d0500a4..8c9d5da 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -323,6 +323,14 @@
def testConvertToTensorAllowsOverflow(self):
_ = ops.convert_to_tensor(123456789, dtype=dtypes.uint8)
+ def testEagerTensorError(self):
+ with self.assertRaisesRegexp(
+ TypeError,
+ "Cannot convert provided value to EagerTensor. "
+ "Provided value.*Requested dtype.*"):
+ _ = ops.convert_to_tensor(1., dtype=dtypes.int32)
+
+
class TFETensorUtilTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index 59828de..2af2b9f 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -1435,7 +1435,7 @@
return HashedCategoricalColumn(key, hash_bucket_size, dtype)
-@tf_export('feature_column.categorical_column_with_vocabulary_file')
+@tf_export(v1=['feature_column.categorical_column_with_vocabulary_file'])
def categorical_column_with_vocabulary_file(key,
vocabulary_file,
vocabulary_size=None,
@@ -1520,6 +1520,97 @@
ValueError: `num_oov_buckets` and `default_value` are both specified.
ValueError: `dtype` is neither string nor integer.
"""
+ return categorical_column_with_vocabulary_file_v2(
+ key, vocabulary_file, vocabulary_size,
+ dtype, default_value,
+ num_oov_buckets)
+
+
+@tf_export('feature_column.categorical_column_with_vocabulary_file', v1=[])
+def categorical_column_with_vocabulary_file_v2(key,
+ vocabulary_file,
+ vocabulary_size=None,
+ dtype=dtypes.string,
+ default_value=None,
+ num_oov_buckets=0):
+ """A `CategoricalColumn` with a vocabulary file.
+
+ Use this when your inputs are in string or integer format, and you have a
+ vocabulary file that maps each value to an integer ID. By default,
+ out-of-vocabulary values are ignored. Use either (but not both) of
+ `num_oov_buckets` and `default_value` to specify how to include
+ out-of-vocabulary values.
+
+ For input dictionary `features`, `features[key]` is either `Tensor` or
+ `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
+ and `''` for string, which will be dropped by this feature column.
+
+ Example with `num_oov_buckets`:
+ File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state
+ abbreviation. All inputs with values in that file are assigned an ID 0-49,
+ corresponding to its line number. All other values are hashed and assigned an
+ ID 50-54.
+
+ ```python
+ states = categorical_column_with_vocabulary_file(
+ key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
+ num_oov_buckets=5)
+ columns = [states, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction = linear_model(features, columns)
+ ```
+
+ Example with `default_value`:
+ File '/us/states.txt' contains 51 lines - the first line is 'XX', and the
+ other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX'
+ in input, and other values missing from the file, will be assigned ID 0. All
+ others are assigned the corresponding line number 1-50.
+
+ ```python
+ states = categorical_column_with_vocabulary_file(
+ key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
+ default_value=0)
+ columns = [states, ...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ linear_prediction, _, _ = linear_model(features, columns)
+ ```
+
+ And to make an embedding with either:
+
+ ```python
+ columns = [embedding_column(states, 3),...]
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ dense_tensor = input_layer(features, columns)
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ vocabulary_file: The vocabulary file name.
+ vocabulary_size: Number of the elements in the vocabulary. This must be no
+ greater than length of `vocabulary_file`, if less than length, later
+ values are ignored. If None, it is set to the length of `vocabulary_file`.
+ dtype: The type of features. Only string and integer types are supported.
+ default_value: The integer ID value to return for out-of-vocabulary feature
+ values, defaults to `-1`. This can not be specified with a positive
+ `num_oov_buckets`.
+ num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
+ buckets. All out-of-vocabulary inputs will be assigned IDs in the range
+ `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
+ the input value. A positive `num_oov_buckets` can not be specified with
+ `default_value`.
+
+ Returns:
+ A `CategoricalColumn` with a vocabulary file.
+
+ Raises:
+ ValueError: `vocabulary_file` is missing or cannot be opened.
+ ValueError: `vocabulary_size` is missing or < 1.
+ ValueError: `num_oov_buckets` is a negative integer.
+ ValueError: `num_oov_buckets` and `default_value` are both specified.
+ ValueError: `dtype` is neither string nor integer.
+ """
if not vocabulary_file:
raise ValueError('Missing vocabulary_file in {}.'.format(key))
diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py
index 53d84b2..ade0797 100644
--- a/tensorflow/python/framework/constant_op.py
+++ b/tensorflow/python/framework/constant_op.py
@@ -114,8 +114,9 @@
return ops.EagerTensor(value, handle, device, dtype)
-@tf_export("constant")
-def constant(value, dtype=None, shape=None, name="Const", verify_shape=False):
+@tf_export(v1=["constant"])
+def constant_v1(
+ value, dtype=None, shape=None, name="Const", verify_shape=False):
"""Creates a constant tensor.
The resulting tensor is populated with values of type `dtype`, as
@@ -174,6 +175,79 @@
Raises:
TypeError: if shape is incorrectly specified or unsupported.
"""
+ return _constant_impl(value, dtype, shape, name, verify_shape=verify_shape,
+ allow_broadcast=False)
+
+
+@tf_export("constant", v1=[])
+def constant(value, dtype=None, shape=None, name="Const"):
+ """Creates a constant tensor.
+
+ The resulting tensor is populated with values of type `dtype`, as
+ specified by arguments `value` and (optionally) `shape` (see examples
+ below).
+
+ The argument `value` can be a constant value, or a list of values of type
+ `dtype`. If `value` is a list, then the length of the list must be less
+ than or equal to the number of elements implied by the `shape` argument (if
+ specified). In the case where the list length is less than the number of
+ elements specified by `shape`, the last element in the list will be used
+ to fill the remaining entries.
+
+ The argument `shape` is optional. If present, it specifies the dimensions of
+ the resulting tensor. If not present, the shape of `value` is used.
+
+ If the argument `dtype` is not specified, then the type is inferred from
+ the type of `value`.
+
+ For example:
+
+ ```python
+ # Constant 1-D Tensor populated with value list.
+ tensor = tf.constant([1, 2, 3, 4, 5, 6]) => [1 2 3 4 5 6]
+
+ # Constant 1-D Tensor populated with value list.
+ tensor = tf.constant([1, 2, 3, 4, 5, 6], shape=(2,3))
+ => [[1 2 3], [4 5 6]]
+
+ # Constant 2-D tensor populated with scalar value -1.
+ tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
+ [-1. -1. -1.]]
+ ```
+
+ `tf.constant` differs from `tf.fill` in a few ways:
+
+ * `tf.constant` supports arbitrary constants, not just uniform scalar
+ Tensors like `tf.fill`.
+ * `tf.constant` creates a `Const` node in the computation graph with the
+ exact value at graph construction time. On the other hand, `tf.fill`
+ creates an Op in the graph that is expanded at runtime.
+ * Because `tf.constant` only embeds constant values in the graph, it does
+ not support dynamic shapes based on other runtime Tensors, whereas
+ `tf.fill` does.
+
+ Args:
+ value: A constant value (or list) of output type `dtype`.
+
+ dtype: The type of the elements of the resulting tensor.
+
+ shape: Optional dimensions of resulting tensor.
+
+ name: Optional name for the tensor.
+
+ Returns:
+ A Constant Tensor.
+
+ Raises:
+ TypeError: if shape is incorrectly specified or unsupported.
+ """
+ return _constant_impl(value, dtype, shape, name, verify_shape=False,
+ allow_broadcast=True)
+
+
+def _constant_impl(
+ value, dtype, shape, name, verify_shape, allow_broadcast):
+ """Implementation of constant."""
ctx = context.context()
if ctx.executing_eagerly():
t = convert_to_eager_tensor(value, ctx, dtype)
@@ -205,7 +279,8 @@
tensor_value = attr_value_pb2.AttrValue()
tensor_value.tensor.CopyFrom(
tensor_util.make_tensor_proto(
- value, dtype=dtype, shape=shape, verify_shape=verify_shape))
+ value, dtype=dtype, shape=shape, verify_shape=verify_shape,
+ allow_broadcast=allow_broadcast))
dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
const_tensor = g.create_op(
"Const", [], [dtype_value.type],
diff --git a/tensorflow/python/framework/file_system_test.py b/tensorflow/python/framework/file_system_test.py
index 6901715..066d34e 100644
--- a/tensorflow/python/framework/file_system_test.py
+++ b/tensorflow/python/framework/file_system_test.py
@@ -42,7 +42,7 @@
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
queue.enqueue_many([["test://foo"]]).run()
queue.close().run()
- key, value = sess.run(reader.read(queue))
+ key, value = self.evaluate(reader.read(queue))
self.assertEqual(key, compat.as_bytes("test://foo"))
self.assertEqual(value, compat.as_bytes("AAAAAAAAAA"))
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 230a554..622686c 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -874,7 +874,7 @@
# If func only returned one value, make it a tuple.
if not isinstance(outputs, (list, tuple)):
outputs = (outputs,)
- if any([_ is None for _ in outputs]):
+ if any(_ is None for _ in outputs):
raise ValueError("Function %s can not return None." % name)
# Ensures each output is a Tensor in the function graph.
outputs = [ops.convert_to_tensor(t) for t in outputs]
@@ -1190,7 +1190,7 @@
def _type_list_to_str(types):
- if any([_ not in _DTYPE_TO_STR for _ in types]):
+ if any(_ not in _DTYPE_TO_STR for _ in types):
raise ValueError("Unsupported dtypes: %s" % types)
return "".join([_DTYPE_TO_STR[_] for _ in types])
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
index 4d1aabd..1803cb9 100644
--- a/tensorflow/python/framework/function_def_to_graph.py
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -174,7 +174,9 @@
# Update inputs of all nodes in graph.
for node_def in graph_def.node:
for i in range(len(node_def.input)):
- node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]
+ # TODO(apassos): how can it not be there?
+ if node_def.input[i] in nested_to_flat_tensor_name:
+ node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]
return graph_def, nested_to_flat_tensor_name
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 90deb97..1a17a48 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -365,7 +365,7 @@
else:
dx, dy = gradients_impl.gradients([z], [x, y])
with session.Session() as sess:
- dx_val, dy_val = sess.run([dx, dy])
+ dx_val, dy_val = self.evaluate([dx, dy])
self.assertEqual([2.0], dx_val)
self.assertEqual([0.0], dy_val)
@@ -854,7 +854,7 @@
z = Bar(x)
with self.session(graph=g) as sess:
- v0, v1 = sess.run([y, z])
+ v0, v1 = self.evaluate([y, z])
self.assertAllEqual(v0, 20.)
self.assertAllEqual(v1, 20.)
@@ -1127,7 +1127,7 @@
dx2, = gradients_impl.gradients(ys=[y2], xs=[x2])
with self.session(graph=g) as sess:
- v0, v1, v2 = sess.run([dx0, dx1, dx2])
+ v0, v1, v2 = self.evaluate([dx0, dx1, dx2])
self.assertAllEqual(v0, 2.)
self.assertAllEqual(v1, 101.)
@@ -1532,7 +1532,7 @@
tf_logging.info("time: %f txt size: %d gdef bin size: %d", finish - start,
len(str(gdef)), len(gdef.SerializeToString()))
with g.as_default(), session.Session(config=cfg) as sess:
- return sess.run(m)
+ return self.evaluate(m)
mv0 = RunForward("complete")
for cfg in _OptimizerOptions():
@@ -1561,7 +1561,7 @@
tf_logging.info("time: %f txt size: %d gdef bin size: %d", finish - start,
len(str(gdef)), len(gdef.SerializeToString()))
with g.as_default(), session.Session(config=cfg) as sess:
- return sess.run(dw)
+ return self.evaluate(dw)
d0 = RunForwardBackward("complete")
for cfg in _OptimizerOptions():
@@ -1705,7 +1705,7 @@
with self.session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
- w, b, x, y0, loss, dw, db = sess.run([w, b, x, y0, loss, dw, db])
+ w, b, x, y0, loss, dw, db = self.evaluate([w, b, x, y0, loss, dw, db])
self.assertAllEqual(w.shape, (64, 64))
self.assertAllClose(np.sum(w), 2050.44)
diff --git a/tensorflow/python/framework/graph_util_test.py b/tensorflow/python/framework/graph_util_test.py
index 7a9f2e8..10a01c7 100644
--- a/tensorflow/python/framework/graph_util_test.py
+++ b/tensorflow/python/framework/graph_util_test.py
@@ -210,7 +210,7 @@
with session.Session() as sess:
init = variables.variables_initializer([variable_node])
- sess.run(init)
+ self.evaluate(init)
output = self.evaluate(output_node)
self.assertNear(4.0, output, 0.00001)
variable_graph_def = sess.graph.as_graph_def()
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index a57f0b3..66e80b5 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -397,7 +397,7 @@
# Run the imported graph.
# TODO(b/76173421): make this work (currently DCHECKS)
# with self.cached_session() as sess:
- # sess.run(imported_init)
+ # self.evaluate(imported_init)
# self.assertEqual(self.evaluate(imported_var), 1.0)
# self.assertEqual(self.evaluate(imported_assign), 2.0)
# self.assertEqual(list(self.evaluate(imported_shape)), [])
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py
index 3605ed7..cc93f8b 100644
--- a/tensorflow/python/framework/meta_graph_test.py
+++ b/tensorflow/python/framework/meta_graph_test.py
@@ -545,7 +545,7 @@
name="")
with session.Session() as sess:
self.evaluate(variables.global_variables_initializer())
- sess.run(x)
+ self.evaluate(x)
def testScopedImportUnderNameScope(self):
graph = ops.Graph()
@@ -600,11 +600,11 @@
with graph.as_default():
variables.Variable(initial_value=1.0, trainable=True)
self.assertTrue(
- all([
+ all(
graph.get_collection(key)
for key in
[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES]
- ]))
+ ))
meta_graph.export_scoped_meta_graph(
filename=meta_graph_filename, graph=graph)
@@ -868,7 +868,7 @@
_, update_op = metrics.mean(values)
initializer = variables.local_variables_initializer()
- sess.run(initializer)
+ self.evaluate(initializer)
self.evaluate(update_op)
meta_graph.export_scoped_meta_graph(
@@ -880,7 +880,7 @@
with self.session(graph=graph) as sess:
meta_graph.import_scoped_meta_graph(meta_graph_filename)
initializer = variables.local_variables_initializer()
- sess.run(initializer)
+ self.evaluate(initializer)
# Verifies that importing an old meta_graph where "local_variables"
# collection is of node_list type works, but cannot build initializer
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index 9955a9a..2318b32 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -570,7 +570,7 @@
"than minimum length %d." %
(input_name, op_type_name, len(values), num_attr.minimum))
# All tensors must have the same base type.
- if any([bt != base_types[0] for bt in base_types]):
+ if any(bt != base_types[0] for bt in base_types):
raise TypeError(
"All tensors passed to '%s' of '%s' Op "
"must have the same type." %
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index c465d2b..5a8a2a4 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -36,6 +36,7 @@
from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as c_api
+from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.eager import core
from tensorflow.python.eager import tape
@@ -747,6 +748,18 @@
def _numpy(self):
raise NotImplementedError()
+ @property
+ def backing_device(self):
+ """Returns the name of the device holding this tensor's memory.
+
+ `.backing_device` is usually the same as `.device`, which returns
+ the device on which the kernel of the operation that produced this tensor
+ ran. However, some operations can produce tensors on a different device
+ (e.g., an operation that executes on the GPU but produces output tensors
+ in host memory).
+ """
+ raise NotImplementedError()
+
def __copy__(self):
# Eager Tensors are immutable so it's safe to return themselves as a copy.
return self
@@ -1023,12 +1036,12 @@
`preferred_dtype` is not possible, this argument has no effect.
Returns:
- An `Output` based on `value`.
+ An `Tensor` based on `value`.
Raises:
- TypeError: If no conversion function is registered for `value`.
+ TypeError: If no conversion function is registered for `value` to `dtype`.
RuntimeError: If a registered conversion function returns an invalid value.
-
+ ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
"""
return convert_to_tensor_v2(value, dtype, preferred_dtype, name)
@@ -1076,12 +1089,12 @@
name: Optional name to use if a new `Tensor` is created.
Returns:
- An `Output` based on `value`.
+ An `Tensor` based on `value`.
Raises:
- TypeError: If no conversion function is registered for `value`.
+ TypeError: If no conversion function is registered for `value` to `dtype`.
RuntimeError: If a registered conversion function returns an invalid value.
-
+ ValueError: If the `value` is a tensor not of given `dtype` in graph mode.
"""
return internal_convert_to_tensor(
value=value,
@@ -1102,42 +1115,7 @@
preferred_dtype=None,
ctx=None,
accept_symbolic_tensors=True):
- """Converts the given `value` to an `Tensor`.
-
- This function converts Python objects of various types to `Tensor`
- objects. It accepts `Tensor` objects, numpy arrays, Python lists,
- and Python scalars. For example:
-
- This function can be useful when composing a new operation in Python
- All standard Python op constructors apply this function to each of their
- Tensor-valued inputs, which allows those ops to accept numpy arrays, Python
- lists, and scalars in addition to `Tensor` objects.
-
- Args:
- value: An object whose type has a registered `Tensor` conversion function.
- dtype: Optional element type for the returned tensor. If missing, the
- type is inferred from the type of `value`.
- name: Optional name to use if a new `Tensor` is created.
- as_ref: True if we want the mutable view of Variables, if applicable.
- preferred_dtype: Optional element type for the returned tensor,
- used when dtype is None. In some cases, a caller may not have a
- dtype in mind when converting to a tensor, so preferred_dtype
- can be used as a soft preference. If the conversion to
- `preferred_dtype` is not possible, this argument has no effect.
- ctx: Optional: The value of context.context().
- accept_symbolic_tensors: Whether Keras graph tensors should be accepted as
- a valid tensor type during eager execution.
- If False, this function will raise an exception if it is passed such
- a tensor during eager eager execution.
-
- Returns:
- A `Tensor` based on `value`.
-
- Raises:
- TypeError: If no conversion function is registered for `value`.
- RuntimeError: If a registered conversion function returns an invalid value.
-
- """
+ """Implementation of the public convert_to_tensor."""
if ctx is None: ctx = context.context()
if isinstance(value, EagerTensor):
if ctx.executing_eagerly():
@@ -2811,8 +2789,8 @@
self._stack_state_is_thread_local = False
self._thread_local = threading.local()
# Functions that will be applied to choose a device if none is specified.
- # After switch_to_thread_local(), self._thread_local._device_function_stack
- # is used instead.
+ # In TF2.x or after switch_to_thread_local(),
+ # self._thread_local._device_function_stack is used instead.
self._graph_device_function_stack = traceable_stack.TraceableStack()
# Default original_op applied to new ops.
self._default_original_op = None
@@ -2820,7 +2798,7 @@
# WhileContext defined in ops/control_flow_ops.py
self._control_flow_context = None
# A new node will depend of the union of all of the nodes in the stack.
- # After switch_to_thread_local(),
+ # In TF2.x or after switch_to_thread_local(),
# self._thread_local._control_dependencies_stack is used instead.
self._graph_control_dependencies_stack = []
# Arbitrary collections of objects.
@@ -2844,7 +2822,7 @@
producer=versions.GRAPH_DEF_VERSION,
min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)
self._building_function = False
- # Stack of colocate_with ops. After switch_to_thread_local(),
+ # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(),
# self._thread_local._colocation_stack is used instead.
self._graph_colocation_stack = traceable_stack.TraceableStack()
# Set of tensors that are dangerous to feed!
@@ -2877,6 +2855,8 @@
# requirement (many custom ops do not have shape functions, and we don't
# want to break these existing cases).
c_api.SetRequireShapeInferenceFns(self._c_graph, False)
+ if tf2.enabled():
+ self.switch_to_thread_local()
# Note: this method is private because the API of tf.Graph() is public and
# frozen, and this functionality is still not ready for public visibility.
@@ -5557,7 +5537,7 @@
app.run(main, argv)
-@tf_export("reset_default_graph")
+@tf_export(v1=["reset_default_graph"])
def reset_default_graph():
"""Clears the default graph stack and resets the global default graph.
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index b9c6908..9c9ef79 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -503,7 +503,7 @@
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Graph is invalid, contains a cycle with 2 nodes"):
- sess.run(x)
+ self.evaluate(x)
def testUpdateInput(self):
g = ops.Graph()
@@ -557,7 +557,7 @@
errors.InvalidArgumentError,
"Input 0 of node add was passed string from Const_1:0 incompatible "
"with expected int32"):
- sess.run(z)
+ self.evaluate(z)
def testUpdateInputShapeError(self):
g = ops.Graph()
@@ -1075,6 +1075,13 @@
node { name: "FloatOutput" op: "FloatOutput" }
""", gd)
+ def testEagerBackingDevice(self):
+ with context.eager_mode():
+ with ops.device("/device:CPU:0"):
+ t = constant_op.constant(1.0)
+ self.assertRegexpMatches(t.device, "/device:CPU:0")
+ self.assertRegexpMatches(t.backing_device, "/device:CPU:0")
+
def testDevicePartialString(self):
g = ops.Graph()
with g.device("/job:worker/replica:2"):
@@ -2390,7 +2397,7 @@
c = math_ops.add(a, b)
# Create a session we can delete
with session.Session(graph=g) as sess:
- sess.run(c)
+ self.evaluate(c)
# Delete all references and trigger gc
del g
del a
@@ -2406,7 +2413,7 @@
math_ops.add([1, 2], [1, 2, 3])
a = constant_op.constant(1)
with session.Session() as sess:
- sess.run(a)
+ self.evaluate(a)
def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
g = ops.Graph()
@@ -2416,7 +2423,7 @@
test_ops.kernel_label_required(1)
a = constant_op.constant(1)
with session.Session() as sess:
- sess.run(a)
+ self.evaluate(a)
class AttrScopeTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py
index 0d20693..6b7f56a 100644
--- a/tensorflow/python/framework/random_seed.py
+++ b/tensorflow/python/framework/random_seed.py
@@ -82,8 +82,7 @@
return seeds
-@tf_export('random.set_random_seed',
- v1=['random.set_random_seed', 'set_random_seed'])
+@tf_export(v1=['random.set_random_seed', 'set_random_seed'])
def set_random_seed(seed):
"""Sets the graph-level random seed.
@@ -183,3 +182,103 @@
context.set_global_seed(seed)
else:
ops.get_default_graph().seed = seed
+
+
+@tf_export('random.set_seed', v1=[])
+def set_seed(seed):
+ """Sets the graph-level random seed.
+
+ Operations that rely on a random seed actually derive it from two seeds:
+ the graph-level and operation-level seeds. This sets the graph-level seed.
+
+ Its interactions with operation-level seeds is as follows:
+
+ 1. If neither the graph-level nor the operation seed is set:
+ A random seed is used for this op.
+ 2. If the graph-level seed is set, but the operation seed is not:
+ The system deterministically picks an operation seed in conjunction
+ with the graph-level seed so that it gets a unique random sequence.
+ 3. If the graph-level seed is not set, but the operation seed is set:
+ A default graph-level seed and the specified operation seed are used to
+ determine the random sequence.
+ 4. If both the graph-level and the operation seed are set:
+ Both seeds are used in conjunction to determine the random sequence.
+
+ To illustrate the user-visible effects, consider these examples:
+
+ To generate different sequences across sessions, set neither
+ graph-level nor op-level seeds:
+
+ ```python
+ a = tf.random_uniform([1])
+ b = tf.random_normal([1])
+
+ print("Session 1")
+ with tf.Session() as sess1:
+ print(sess1.run(a)) # generates 'A1'
+ print(sess1.run(a)) # generates 'A2'
+ print(sess1.run(b)) # generates 'B1'
+ print(sess1.run(b)) # generates 'B2'
+
+ print("Session 2")
+ with tf.Session() as sess2:
+ print(sess2.run(a)) # generates 'A3'
+ print(sess2.run(a)) # generates 'A4'
+ print(sess2.run(b)) # generates 'B3'
+ print(sess2.run(b)) # generates 'B4'
+ ```
+
+ To generate the same repeatable sequence for an op across sessions, set the
+ seed for the op:
+
+ ```python
+ a = tf.random_uniform([1], seed=1)
+ b = tf.random_normal([1])
+
+ # Repeatedly running this block with the same graph will generate the same
+ # sequence of values for 'a', but different sequences of values for 'b'.
+ print("Session 1")
+ with tf.Session() as sess1:
+ print(sess1.run(a)) # generates 'A1'
+ print(sess1.run(a)) # generates 'A2'
+ print(sess1.run(b)) # generates 'B1'
+ print(sess1.run(b)) # generates 'B2'
+
+ print("Session 2")
+ with tf.Session() as sess2:
+ print(sess2.run(a)) # generates 'A1'
+ print(sess2.run(a)) # generates 'A2'
+ print(sess2.run(b)) # generates 'B3'
+ print(sess2.run(b)) # generates 'B4'
+ ```
+
+ To make the random sequences generated by all ops be repeatable across
+ sessions, set a graph-level seed:
+
+ ```python
+ tf.random.set_seed(1234)
+ a = tf.random_uniform([1])
+ b = tf.random_normal([1])
+
+ # Repeatedly running this block with the same graph will generate the same
+ # sequences of 'a' and 'b'.
+ print("Session 1")
+ with tf.Session() as sess1:
+ print(sess1.run(a)) # generates 'A1'
+ print(sess1.run(a)) # generates 'A2'
+ print(sess1.run(b)) # generates 'B1'
+ print(sess1.run(b)) # generates 'B2'
+
+ print("Session 2")
+ with tf.Session() as sess2:
+ print(sess2.run(a)) # generates 'A1'
+ print(sess2.run(a)) # generates 'A2'
+ print(sess2.run(b)) # generates 'B1'
+ print(sess2.run(b)) # generates 'B2'
+ ```
+
+ Args:
+ seed: integer.
+ """
+ # TODO(go/tf2-random): change doc, update to match design doc
+ set_random_seed(seed)
diff --git a/tensorflow/python/framework/registry.py b/tensorflow/python/framework/registry.py
index 2e45acb..4357c76 100644
--- a/tensorflow/python/framework/registry.py
+++ b/tensorflow/python/framework/registry.py
@@ -23,10 +23,9 @@
from __future__ import division
from __future__ import print_function
-import traceback
-
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
+from tensorflow.python.util import tf_stack
# Registry mechanism below is based on mapreduce.python.mrpython.Register.
@@ -57,15 +56,17 @@
if name in self._registry:
(filename, line_number, function_name, _) = (
self._registry[name][_LOCATION_TAG])
- raise KeyError("Registering two %s with name '%s' !"
+ raise KeyError("Registering two %s with name '%s'! "
"(Previous registration was in %s %s:%d)" %
(self._name, name, function_name, filename, line_number))
logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name)
# stack trace is [this_function, Register(), user_function,...]
# so the user function is #2.
- stack = traceback.extract_stack()
- self._registry[name] = {_TYPE_TAG: candidate, _LOCATION_TAG: stack[2]}
+ stack = tf_stack.extract_stack()
+ user_function = stack[2]
+ location_tag = tf_stack.convert_stack([user_function])[0]
+ self._registry[name] = {_TYPE_TAG: candidate, _LOCATION_TAG: location_tag}
def list(self):
"""Lists registered items.
diff --git a/tensorflow/python/framework/registry_test.py b/tensorflow/python/framework/registry_test.py
index a821e16..1a0d3f2 100644
--- a/tensorflow/python/framework/registry_test.py
+++ b/tensorflow/python/framework/registry_test.py
@@ -45,7 +45,9 @@
def testDuplicate(self):
myreg = registry.Registry('testbar')
myreg.register(bar, 'Bar')
- with self.assertRaises(KeyError):
+ with self.assertRaisesRegexp(
+ KeyError, r'Registering two testbar with name \'Bar\'! '
+ r'\(Previous registration was in [^ ]+ .*.py:[0-9]+\)'):
myreg.register(bar, 'Bar')
diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py
index cab4268..5322204 100644
--- a/tensorflow/python/framework/subscribe_test.py
+++ b/tensorflow/python/framework/subscribe_test.py
@@ -66,9 +66,9 @@
self.assertTrue(c.op in d.op.control_inputs)
with self.cached_session() as sess:
- c_out = sess.run([c])
- n_out = sess.run([n])
- d_out = sess.run([d])
+ c_out = self.evaluate([c])
+ n_out = self.evaluate([n])
+ d_out = self.evaluate([d])
self.assertEqual(n_out, [-2])
self.assertEqual(c_out, [2])
@@ -145,8 +145,8 @@
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
with self.cached_session() as sess:
- c_out = sess.run([c])
- d_out = sess.run([d])
+ c_out = self.evaluate([c])
+ d_out = self.evaluate([d])
self.assertEqual(c_out, [42])
self.assertEqual(d_out, [11])
@@ -205,7 +205,7 @@
# Expect the three side effect graphs to have been evaluated.
with self.cached_session() as sess:
- sess.run([c_sub])
+ self.evaluate([c_sub])
self.assertIn('graph1', shared)
self.assertIn('graph2', shared)
self.assertIn('graph3', shared)
@@ -229,20 +229,20 @@
with self.cached_session() as sess:
# Initialize the variables first.
- sess.run([v1.initializer])
- sess.run([v2.initializer])
+ self.evaluate([v1.initializer])
+ self.evaluate([v2.initializer])
# Expect the side effects to be triggered when evaluating the add op as
# it will read the value of the variable.
- sess.run([add])
+ self.evaluate([add])
self.assertEqual(1, len(shared))
# Expect the side effect not to be triggered when evaluating the assign
# op as it will not access the 'read' output of the variable.
- sess.run([assign_v1])
+ self.evaluate([assign_v1])
self.assertEqual(1, len(shared))
- sess.run([add])
+ self.evaluate([add])
self.assertEqual(2, len(shared))
# Make sure the values read from the variable match the expected ones.
@@ -273,7 +273,7 @@
self.assertFalse(subscribe._is_subscribed_identity(tensor_array.handle))
with self.cached_session() as sess:
- sess.run([reader])
+ self.evaluate([reader])
self.assertEqual(0, len(shared))
def testMultipleOutputs(self):
@@ -304,7 +304,7 @@
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
with self.cached_session() as sess:
- sess.run([neg])
+ self.evaluate([neg])
# All three ops have been processed.
self.assertEqual(3, len(shared))
@@ -375,7 +375,7 @@
self.assertIsNot(context(subscriptions[0]), context(subscriptions[1]))
with self.cached_session() as sess:
- sess.run(cond)
+ self.evaluate(cond)
self.assertEqual(3, len(results))
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 9db94f5..f98f301 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -371,8 +371,10 @@
(dtype.name, repr(mismatch), type(mismatch).__name__))
+# pylint: disable=invalid-name
@tf_export(v1=["make_tensor_proto"])
-def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False):
+def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False,
+ allow_broadcast=False):
"""Create a TensorProto.
Args:
@@ -380,6 +382,8 @@
dtype: Optional tensor_pb2 DataType value.
shape: List of integers representing the dimensions of tensor.
verify_shape: Boolean that enables verification of a shape of values.
+ allow_broadcast:Boolean that enables allowing scalars and 1 length vector
+ broadcasting. Cannot be true when verify_shape is true.
Returns:
A `TensorProto`. Depending on the type, it may contain data in the
@@ -416,6 +420,8 @@
can not have more elements than what "shape" specifies.
"""
+ if allow_broadcast and verify_shape:
+ raise ValueError("allow_broadcast and verify_shape are not both allowed.")
if isinstance(values, tensor_pb2.TensorProto):
return values
@@ -504,15 +510,22 @@
shape_size = np.prod(shape, dtype=np.int64)
is_same_size = shape_size == nparray.size
- if verify_shape:
- if not nparray.shape == tuple(shape):
+ if allow_broadcast:
+ if nparray.shape == (1,) or nparray.shape == tuple():
+ pass
+ elif nparray.size != shape_size:
raise TypeError("Expected Tensor's shape: %s, got %s." %
(tuple(shape), nparray.shape))
- if nparray.size > shape_size:
- raise ValueError(
- "Too many elements provided. Needed at most %d, but received %d" %
- (shape_size, nparray.size))
+ else:
+ if verify_shape and nparray.shape != tuple(shape):
+ raise TypeError("Expected Tensor's shape: %s, got %s." %
+ (tuple(shape), nparray.shape))
+
+ if nparray.size > shape_size:
+ raise ValueError(
+ "Too many elements provided. Needed at most %d, but received %d" %
+ (shape_size, nparray.size))
tensor_proto = tensor_pb2.TensorProto(
dtype=numpy_dtype.as_datatype_enum,
@@ -560,6 +573,7 @@
append_fn(tensor_proto, proto_values)
return tensor_proto
+# pylint: enable=invalid-name
@tf_export("make_ndarray")
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 8971227..fc97a22 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -892,8 +892,8 @@
"""Execute all test methods in the given class with and without eager."""
base_decorator = run_in_graph_and_eager_modes
for name, value in cls.__dict__.copy().items():
- if callable(value) and name.startswith(
- "test") and not name.startswith("testSkipEager"):
+ if callable(value) and name.startswith("test") and not (
+ name.startswith("testSkipEager") or name.startswith("test_skip_eager")):
setattr(cls, name, base_decorator(value))
return cls
@@ -960,7 +960,7 @@
def decorator(f):
if tf_inspect.isclass(f):
raise ValueError(
- "`run_test_in_graph_and_eager_modes` only supports test methods. "
+ "`run_in_graph_and_eager_modes` only supports test methods. "
"Did you mean to use `run_all_in_graph_and_eager_modes`?")
def decorated(self, *args, **kwargs):
@@ -1005,6 +1005,38 @@
return decorator
+def run_deprecated_v1(func=None):
+ """Execute the decorated test in graph mode.
+
+ This function returns a decorator intended to be applied to tests that have
+ not been updated to a style that is compatible with both TensorFlow 1.x and
+ 2.x. When this decorated is applied, the test body will be run in
+ an environment where API calls construct graphs instead of executing eagerly.
+
+ Args:
+ func: function to be annotated. If `func` is None, this method returns a
+ decorator the can be applied to a function. If `func` is not None this
+ returns the decorator applied to `func`.
+ Returns:
+ Returns a decorator that will run the decorated test method in graph mode.
+ """
+
+ def decorator(f):
+ if tf_inspect.isclass(f):
+ raise ValueError("`run_deprecated_v1` only supports test methods.")
+
+ def decorated(self, *args, **kwargs):
+ with context.graph_mode():
+ f(self, *args, **kwargs)
+
+ return decorated
+
+ if func is not None:
+ return decorator(func)
+
+ return decorator
+
+
@tf_export("test.is_gpu_available")
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
"""Returns whether TensorFlow can access a GPU.
@@ -1044,7 +1076,7 @@
return True
return False
except errors_impl.NotFoundError as e:
- if not all([x in str(e) for x in ["CUDA", "not find"]]):
+ if not all(x in str(e) for x in ["CUDA", "not find"]):
raise e
else:
logging.error(str(e))
@@ -1062,6 +1094,27 @@
yield
+@contextlib.contextmanager
+def use_gpu():
+ """Uses gpu when requested and available."""
+ with device(use_gpu=True):
+ yield
+
+
+@contextlib.contextmanager
+def force_gpu():
+ """Force the gpu to be used."""
+ with ops.device("/device:GPU:0"):
+ yield
+
+
+@contextlib.contextmanager
+def force_cpu():
+ """Force the cpu to be used."""
+ with ops.device("/device:CPU:0"):
+ yield
+
+
class CapturedWrites(object):
"""A utility class to load the captured writes made to a stream."""
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index cbefe86..2a37253 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -681,7 +681,7 @@
self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def))
def test_run_in_eager_and_graph_modes_test_class(self):
- msg = "`run_test_in_graph_and_eager_modes` only supports test methods.*"
+ msg = "`run_in_graph_and_eager_modes` only supports test methods.*"
with self.assertRaisesRegexp(ValueError, msg):
@test_util.run_in_graph_and_eager_modes()
class Foo(object):
diff --git a/tensorflow/python/grappler/cost_analyzer_test.py b/tensorflow/python/grappler/cost_analyzer_test.py
index b8225b8..de80df1 100644
--- a/tensorflow/python/grappler/cost_analyzer_test.py
+++ b/tensorflow/python/grappler/cost_analyzer_test.py
@@ -96,8 +96,8 @@
b_fc = variables.Variable(random_ops.truncated_normal([10], stddev=0.1))
y_conv = nn_ops.softmax(math_ops.matmul(h_conv_flat, w_fc) + b_fc)
- cross_entropy = math_ops.reduce_mean(-math_ops.reduce_sum(
- label * math_ops.log(y_conv), reduction_indices=[1]))
+ cross_entropy = math_ops.reduce_mean(
+ -math_ops.reduce_sum(label * math_ops.log(y_conv), axis=[1]))
_ = adam.AdamOptimizer(1e-4).minimize(cross_entropy)
mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i
index daa5bc9..b746c3e 100644
--- a/tensorflow/python/grappler/tf_optimizer.i
+++ b/tensorflow/python/grappler/tf_optimizer.i
@@ -74,13 +74,13 @@
void DetectDevices(std::unordered_map<string, tensorflow::DeviceProperties>* device_map) {
tensorflow::SessionOptions options;
- std::vector<tensorflow::Device*> devices;
+ std::vector<std::unique_ptr<tensorflow::Device>> devices;
tensorflow::Status status = tensorflow::DeviceFactory::AddDevices(options, "", &devices);
if (!status.ok()) {
return;
}
- for (const tensorflow::Device* device : devices) {
+ for (const std::unique_ptr<tensorflow::Device>& device : devices) {
tensorflow::DeviceProperties& prop = (*device_map)[device->name()];
prop = tensorflow::grappler::GetDeviceInfo(device->parsed_name());
@@ -88,7 +88,6 @@
// available device memory.
const tensorflow::DeviceAttributes& attr = device->attributes();
prop.set_memory_size(attr.memory_limit());
- delete device;
}
}
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index adfa226..69e18ea 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -122,8 +122,10 @@
"constraints.py",
"engine/__init__.py",
"engine/base_layer.py",
+ "engine/base_layer_utils.py",
"engine/distributed_training_utils.py",
"engine/input_layer.py",
+ "engine/input_spec.py",
"engine/network.py",
"engine/saving.py",
"engine/sequential.py",
@@ -141,6 +143,7 @@
"regularizers.py",
"utils/data_utils.py",
"utils/io_utils.py",
+ "utils/losses_utils.py",
],
srcs_version = "PY2AND3",
deps = [
@@ -182,7 +185,6 @@
":engine",
"//tensorflow/python:array_ops",
"//tensorflow/python:cudnn_rnn_ops_gen",
- "//tensorflow/python:distribute",
"//tensorflow/python:dtypes",
"//tensorflow/python:embedding_ops",
"//tensorflow/python:framework_ops",
@@ -196,6 +198,7 @@
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
+ "//tensorflow/python/distribute:distribute_lib",
],
)
@@ -266,6 +269,7 @@
name = "optimizers_test",
size = "medium",
srcs = ["optimizers_test.py"],
+ shard_count = 2,
srcs_version = "PY2AND3",
tags = ["notsan"],
deps = [
@@ -273,6 +277,7 @@
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -302,6 +307,7 @@
":keras",
"//tensorflow/python:client_testlib",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -728,6 +734,7 @@
":keras",
"//tensorflow/python:client_testlib",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -741,6 +748,7 @@
":keras",
"//tensorflow/python:client_testlib",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index ac89621..c765464 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -2325,7 +2325,7 @@
else:
axis = 0
- if py_all([is_sparse(x) for x in tensors]):
+ if py_all(is_sparse(x) for x in tensors):
return sparse_ops.sparse_concat(axis, tensors)
else:
return array_ops.concat([to_dense(x) for x in tensors], axis)
@@ -3347,9 +3347,9 @@
assert not nest.is_sequence(input_t)
rank_diff = len(input_t.shape) - len(mask_t.shape)
for _ in range(rank_diff):
- mask_t = array_ops.expand_dims(mask_t)
- expand_dims = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
- return array_ops.tile(mask_t, expand_dims)
+ mask_t = array_ops.expand_dims(mask_t, -1)
+ multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
+ return array_ops.tile(mask_t, multiples)
if unroll:
if not time_steps:
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 0ab651b..48fdd56 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -1223,6 +1223,121 @@
for s, u_s in zip(additional_state_list[2], additional_state_list[3]):
self.assertAllClose(s, u_s, atol=1e-04)
+ def test_rnn_output_and_state_masking_independent(self):
+ num_samples = 2
+ num_timesteps = 4
+ state_and_io_size = 2
+ mask_last_num_timesteps = 2 # for second sample only
+
+ # a step function that just outputs inputs,
+ # but increments states +1 per timestep
+ def step_function(inputs, states):
+ return inputs, [s + 1 for s in states]
+
+ inputs_vals = np.random.random((num_samples, num_timesteps,
+ state_and_io_size))
+ initial_state_vals = np.random.random((num_samples, state_and_io_size))
+ # masking of two last timesteps for second sample only
+ mask_vals = np.ones((num_samples, num_timesteps))
+ mask_vals[1, -mask_last_num_timesteps:] = 0
+
+ # outputs expected to be same as inputs for the first sample
+ expected_outputs = inputs_vals.copy()
+ # but for the second sample all outputs in masked region should be the same
+ # as last output before masked region
+ expected_outputs[1, -mask_last_num_timesteps:] = \
+ expected_outputs[1, -(mask_last_num_timesteps + 1)]
+
+ expected_last_state = initial_state_vals.copy()
+ # first state should be incremented for every timestep (no masking)
+ expected_last_state[0] += num_timesteps
+ # second state should not be incremented for last two timesteps
+ expected_last_state[1] += (num_timesteps - mask_last_num_timesteps)
+
+ # verify same expected output for `unroll=true/false`
+ inputs = keras.backend.variable(inputs_vals)
+ initial_states = [keras.backend.variable(initial_state_vals)]
+ mask = keras.backend.variable(mask_vals)
+ for unroll in [True, False]:
+ _, outputs, last_states = keras.backend.rnn(
+ step_function,
+ inputs,
+ initial_states,
+ mask=mask,
+ unroll=unroll,
+ input_length=num_timesteps if unroll else None)
+
+ self.assertAllClose(keras.backend.eval(outputs), expected_outputs)
+ self.assertAllClose(
+ keras.backend.eval(last_states[0]), expected_last_state)
+
+ def test_rnn_output_num_dim_larger_than_2_masking(self):
+ num_samples = 3
+ num_timesteps = 4
+ num_features = 5
+
+ def step_function(inputs, states):
+ outputs = keras.backend.tile(keras.backend.expand_dims(inputs), [1, 1, 2])
+ return outputs, [keras.backend.identity(s) for s in states]
+ # Note: cannot just return states (which can be a problem) ->
+ # tensorflow/python/ops/resource_variable_ops.py", line 824, in set_shape
+ # NotImplementedError: ResourceVariable does not implement set_shape()
+
+ inputs_vals = np.random.random((num_samples, num_timesteps, num_features))
+ initial_state_vals = np.random.random((num_samples, 6))
+ mask_vals = np.ones((num_samples, num_timesteps))
+ mask_vals[-1, -1] = 0 # final timestep masked for last sample
+
+ expected_outputs = np.repeat(inputs_vals[..., None], repeats=2, axis=-1)
+ # for the last sample, the final timestep (in masked region) should be the
+ # same as the second to final output (before masked region)
+ expected_outputs[-1, -1] = expected_outputs[-1, -2]
+
+ inputs = keras.backend.variable(inputs_vals)
+ initial_states = [keras.backend.variable(initial_state_vals)]
+ mask = keras.backend.variable(mask_vals)
+ for unroll in [True, False]:
+ _, outputs, _ = keras.backend.rnn(
+ step_function,
+ inputs,
+ initial_states,
+ mask=mask,
+ unroll=unroll,
+ input_length=num_timesteps if unroll else None)
+
+ self.assertAllClose(keras.backend.eval(outputs), expected_outputs)
+
+ def test_rnn_state_num_dim_larger_than_2_masking(self):
+ num_samples = 3
+ num_timesteps = 4
+
+ def step_function(inputs, states):
+ return inputs, [s + 1 for s in states]
+
+ inputs_vals = np.random.random((num_samples, num_timesteps, 5))
+ initial_state_vals = np.random.random((num_samples, 6, 7))
+ mask_vals = np.ones((num_samples, num_timesteps))
+ mask_vals[0, -2:] = 0 # final two timesteps masked for first sample
+
+ expected_last_state = initial_state_vals.copy()
+ expected_last_state[0] += (num_timesteps - 2)
+ expected_last_state[1:] += num_timesteps
+
+ inputs = keras.backend.variable(inputs_vals)
+ initial_states = [keras.backend.variable(initial_state_vals)]
+ mask = keras.backend.variable(mask_vals)
+ for unroll in [True, False]:
+ _, _, last_states = keras.backend.rnn(
+ step_function,
+ inputs,
+ initial_states,
+ mask=mask,
+ unroll=unroll,
+ input_length=num_timesteps if unroll else None)
+
+ self.assertAllClose(
+ keras.backend.eval(last_states[0]), expected_last_state)
+
def test_normalize_batch_in_training(self):
val = np.random.random((10, 3, 10, 10))
x = keras.backend.variable(val)
diff --git a/tensorflow/python/keras/engine/__init__.py b/tensorflow/python/keras/engine/__init__.py
index 26aed34..005f646 100644
--- a/tensorflow/python/keras/engine/__init__.py
+++ b/tensorflow/python/keras/engine/__init__.py
@@ -20,10 +20,10 @@
# TODO(fchollet): Remove hourglass imports once external code is done importing
# non-public APIs.
-from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
+from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils.layer_utils import get_source_inputs
del absolute_import
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 5dcbc4d..8b79593 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -18,8 +18,6 @@
from __future__ import division
from __future__ import print_function
-import collections as collections_lib
-import enum # pylint: disable=g-bad-import-order
import functools
import inspect # Necessary supplement to tf_inspect to deal with variadic args.
@@ -36,13 +34,14 @@
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
+from tensorflow.python.keras.engine import base_layer_utils
+from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
# A module that only depends on `keras.layers` import these from here.
from tensorflow.python.keras.utils.generic_utils import to_snake_case # pylint: disable=unused-import
from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -54,28 +53,6 @@
from tensorflow.tools.docs import doc_controls
-class CallConvention(enum.Enum):
- """Calling conventions for passing `Layer` inputs to `Layer.call`."""
- # The Layer takes inputs as its first argument, named "inputs" for
- # compatibility with the signature of Layer.__call__. This is the mode assumed
- # for Layers which are not subclassed Models.
- EXPLICIT_INPUTS_ARGUMENT = 1
- # The Layer takes a single positional argument, not named "inputs". It's
- # treated like an "inputs" argument.
- SINGLE_POSITIONAL_ARGUMENT = 2
- # The Layer has multiple positional arguments to which its inputs should be
- # bound.
- POSITIONAL_ARGUMENTS_ARE_INPUTS = 3
-
-
-def _create_mean_metric(value, name=None):
- # TODO(psv): Remove this import when b/110718070 is fixed.
- from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
- metric_obj = metrics_module.Mean(name=name)
- result = metric_obj(value)
- return metric_obj, result
-
-
@tf_export('keras.layers.Layer')
class Layer(checkpointable.CheckpointableBase):
"""Base layer class.
@@ -110,10 +87,6 @@
name: The name of the layer (string).
dtype: Default dtype of the layer's weights (default of `None` means use the
type of the first input).
- trainable_variables: List of trainable variables.
- non_trainable_variables: List of non-trainable variables.
- variables: List of all variables of this layer, trainable and
- non-trainable.
updates: List of update ops of this layer.
losses: List of losses added by this layer.
trainable_weights: List of variables to be included in backprop.
@@ -158,9 +131,9 @@
self.built = False
# Provides information about which inputs are compatible with the layer.
self.input_spec = None
+ self.supports_masking = False
self._init_set_name(name)
-
self._activity_regularizer = kwargs.pop('activity_regularizer', None)
self._trainable_weights = []
self._non_trainable_weights = []
@@ -189,15 +162,14 @@
self._call_fn_args = function_utils.fn_args(self.call)
self._compute_previous_mask = ('mask' in self._call_fn_args or
hasattr(self, 'compute_mask'))
- self._call_convention = CallConvention.EXPLICIT_INPUTS_ARGUMENT
+ self._call_convention = (base_layer_utils
+ .CallConvention.EXPLICIT_INPUTS_ARGUMENT)
# These lists will be filled via successive calls
# to self._add_inbound_node().
self._inbound_nodes = []
self._outbound_nodes = []
- self.supports_masking = False
-
call_argspec = tf_inspect.getfullargspec(self.call)
if 'training' in call_argspec.args:
self._expects_training_arg = True
@@ -227,345 +199,22 @@
else:
self._initial_weights = None
- def _init_set_name(self, name, zero_based=True):
- if not name:
- self._name = unique_layer_name(
- generic_utils.to_snake_case(self.__class__.__name__),
- zero_based=zero_based)
- else:
- self._name = name
-
- @property
- def dtype(self):
- return self._dtype
-
- @property
- def name(self):
- return self._name
-
- @property
- def activity_regularizer(self):
- """Optional regularizer function for the output of this layer."""
- return self._activity_regularizer
-
- @activity_regularizer.setter
- def activity_regularizer(self, regularizer):
- """Optional regularizer function for the output of this layer."""
- self._activity_regularizer = self._no_dependency(regularizer)
-
- @property
- def trainable_weights(self):
- return self._trainable_weights if self.trainable else []
-
- @property
- def non_trainable_weights(self):
- if self.trainable:
- return self._non_trainable_weights
- else:
- return self._trainable_weights + self._non_trainable_weights
-
- @property
- def trainable_variables(self):
- return self.trainable_weights
-
- @property
- def non_trainable_variables(self):
- return self.non_trainable_weights
-
- @property
- def weights(self):
- """Returns the list of all layer variables/weights.
-
- Returns:
- A list of variables.
- """
- return self.trainable_weights + self.non_trainable_weights
-
- @property
- def variables(self):
- """Returns the list of all layer variables/weights.
-
- Returns:
- A list of variables.
- """
- return self.weights
-
- @property
- def updates(self):
- if not self.trainable and not self.stateful:
- return []
- return self._updates
-
- @doc_controls.for_subclass_implementers
- def add_update(self, updates, inputs=None):
- """Add update op(s), potentially dependent on layer inputs.
-
- Weight updates (for instance, the updates of the moving mean and variance
- in a BatchNormalization layer) may be dependent on the inputs passed
- when calling a layer. Hence, when reusing the same layer on
- different inputs `a` and `b`, some entries in `layer.updates` may be
- dependent on `a` and some on `b`. This method automatically keeps track
- of dependencies.
-
- The `get_updates_for` method allows to retrieve the updates relevant to a
- specific set of inputs.
-
- This call is ignored when eager execution is enabled (in that case, variable
- updates are run on the fly and thus do not need to be tracked for later
- execution).
-
- Arguments:
- updates: Update op, or list/tuple of update ops.
- inputs: If anything other than None is passed, it signals the updates
- are conditional on some of the layer's inputs,
- and thus they should only be run where these inputs are available.
- This is the case for BatchNormalization updates, for instance.
- If None, the updates will be taken into account unconditionally,
- and you are responsible for making sure that any dependency they might
- have is available at runtime.
- A step counter might fall into this category.
- """
- if context.executing_eagerly():
- return # Updates already applied when in eager mode.
-
- def process_update(x):
- if isinstance(x, ops.Operation):
- return x
- elif hasattr(x, 'op'):
- return x.op
- else:
- return ops.convert_to_tensor(x)
-
- updates = generic_utils.to_list(updates)
- updates = [process_update(x) for x in updates]
- self._updates += updates
- if inputs is None:
- for u in updates:
- u._unconditional_update = True # pylint: disable=protected-access
- else:
- for u in updates:
- u._unconditional_update = False # pylint: disable=protected-access
-
- def get_updates_for(self, inputs):
- """Retrieves updates relevant to a specific set of inputs.
-
- Arguments:
- inputs: Input tensor or list/tuple of input tensors.
-
- Returns:
- List of update ops of the layer that depend on `inputs`.
-
- Raises:
- RuntimeError: If called in Eager mode.
- """
- # Updates disabled if layer is not trainable and not explicitly stateful.
- if not self.trainable and not self.stateful:
- return []
-
- if inputs is None:
- # Requesting unconditional updates.
- return [x for x in self.updates if x._unconditional_update] # pylint: disable=protected-access
-
- # Requesting input-conditional updates.
- inputs = nest.flatten(inputs)
- reachable = tf_utils.get_reachable_from_inputs(inputs, self.updates)
- updates = []
- for update in self.updates:
- if update in reachable:
- updates.append(update)
- return updates
-
- @property
- def losses(self):
- """Losses which are associated with this `Layer`.
-
- Variable regularization tensors are created when this property is accessed,
- so it is eager safe: accessing `losses` under a `tf.GradientTape` will
- propagate gradients back to the corresponding variables.
-
- Returns:
- A list of tensors.
- """
- collected_losses = []
- if context.executing_eagerly():
- collected_losses.extend(self._eager_losses)
- else:
- collected_losses.extend(self._losses)
- for regularizer in self._callable_losses:
- loss_tensor = regularizer()
- if loss_tensor is not None:
- collected_losses.append(loss_tensor)
- return collected_losses
-
- @doc_controls.for_subclass_implementers
- def add_loss(self, losses, inputs=None):
- """Add loss tensor(s), potentially dependent on layer inputs.
-
- Some losses (for instance, activity regularization losses) may be dependent
- on the inputs passed when calling a layer. Hence, when reusing the same
- layer on different inputs `a` and `b`, some entries in `layer.losses` may
- be dependent on `a` and some on `b`. This method automatically keeps track
- of dependencies.
-
- The `get_losses_for` method allows to retrieve the losses relevant to a
- specific set of inputs.
-
- Note that `add_loss` is not supported when executing eagerly. Instead,
- variable regularizers may be added through `add_variable`. Activity
- regularization is not supported directly (but such losses may be returned
- from `Layer.call()`).
-
- Arguments:
- losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
- may also be zero-argument callables which create a loss tensor.
- inputs: Ignored when executing eagerly. If anything other than None is
- passed, it signals the losses are conditional on some of the layer's
- inputs, and thus they should only be run where these inputs are
- available. This is the case for activity regularization losses, for
- instance. If `None` is passed, the losses are assumed
- to be unconditional, and will apply across all dataflows of the layer
- (e.g. weight regularization losses).
- """
- losses = generic_utils.to_list(losses)
-
- def _tag_unconditional(loss):
- if callable(loss):
- loss = loss()
- if loss is None:
- return None # Will be filtered out when computing the .losses property
- if not tensor_util.is_tensor(loss):
- loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
- loss._unconditional_loss = (inputs is None) # pylint: disable=protected-access
- return loss
-
- for loss in losses:
- if callable(loss):
- self._callable_losses.append(
- functools.partial(_tag_unconditional, loss))
- else:
- if context.executing_eagerly():
- self._eager_losses.append(_tag_unconditional(loss))
- else:
- self._losses.append(_tag_unconditional(loss))
-
- @doc_controls.for_subclass_implementers
- def add_metric(self, value, aggregation=None, name=None):
- """Adds metric tensor to the layer.
-
- Args:
- value: Metric tensor.
- aggregation: Sample-wise metric reduction function. If `aggregation=None`,
- it indicates that the metric tensor provided has been aggregated
- already. eg, `model.add_metric(BinaryAccuracy(name='acc')(y_true,
- y_pred))`. If aggregation='mean', the given metric tensor will be
- sample-wise reduced using `mean` function. eg, `model.add_metric(
- tf.reduce_mean(outputs), name='output_mean', aggregation='mean')`.
- name: String metric name.
-
- Raises:
- ValueError: If `aggregation` is anything other than None or `mean`.
- """
- if aggregation is not None and aggregation != 'mean':
- raise ValueError(
- 'We currently support only `mean` sample-wise metric aggregation. '
- 'You provided aggregation=`%s`' % aggregation)
-
- if tf_utils.is_symbolic_tensor(value):
- self._symbolic_add_metric(value, aggregation, name)
- else:
- self._eager_add_metric(value, aggregation, name)
-
- def _get_existing_metric(self, name=None):
- match = [m for m in self._metrics if m.name == name]
- if not match:
- return
- if len(match) > 1:
- raise ValueError(
- 'Please provide different names for the metrics you have added. '
- 'We found {} metrics with the name: "{}"'.format(len(match), name))
- return match[0]
-
- def _eager_add_metric(self, value, aggregation=None, name=None):
- # If the given metric is available in `metrics` list we just update state
- # on it, otherwise we create a new metric instance and
- # add it to the `metrics` list.
- match = self._get_existing_metric(name)
- if match:
- match(value) # Update the metric state.
- return
- else:
- if aggregation is None:
- raise ValueError('We do not support adding an aggregated metric tensor '
- 'in `call` in eager execution.')
- metric_obj, _ = _create_mean_metric(value, name)
- self._metrics.append(metric_obj)
-
- def _symbolic_add_metric(self, value, aggregation=None, name=None):
- if aggregation is None:
- # Iterate over the metrics and check if the given metric exists already.
- # This can happen when a metric instance is created in subclassed model
- # layer `__init__` and we have tracked that instance already in
- # model.__setattr__.
- match = self._get_existing_metric(name)
- if match:
- result_tensor = value
- if match.name not in self._metrics_tensors:
- self._metrics_tensors[match.name] = result_tensor
- return
- else:
- raise ValueError(
- 'We currently do not support reusing a metric instance.')
- else:
- # We track the instance using the metadata on the result tensor.
- result_tensor = value
- metric_obj = result_tensor._metric_obj
- else:
- # If a non-aggregated tensor is given as input (ie. `aggregation` is
- # explicitly set to `mean`), we wrap the tensor in `Mean` metric.
- metric_obj, result_tensor = _create_mean_metric(value, name)
- self._metrics.append(metric_obj)
- self._metrics_tensors[metric_obj.name] = result_tensor
-
- def get_losses_for(self, inputs):
- """Retrieves losses relevant to a specific set of inputs.
-
- Arguments:
- inputs: Input tensor or list/tuple of input tensors.
-
- Returns:
- List of loss tensors of the layer that depend on `inputs`.
-
- Raises:
- RuntimeError: If called in Eager mode.
- """
- if inputs is None:
- # Requesting unconditional losses.
- return [x for x in self.losses if x._unconditional_loss] # pylint: disable=protected-access
-
- # Requesting input-conditional losses.
- inputs = nest.flatten(inputs)
- # Retrieve the set of tensors in the TF graph that depend on `inputs`.
- # The losses we want to return will be part of this set.
- # To avoid unnecessary work, we stop the search in case all of
- # `self.losses` have been retrieved.
- reachable = tf_utils.get_reachable_from_inputs(inputs, self.losses)
- losses = []
- for loss in self.losses:
- if loss in reachable:
- losses.append(loss)
- return losses
-
- def _name_scope(self):
- return self.name
-
def build(self, input_shape):
"""Creates the variables of the layer."""
self.built = True
@doc_controls.for_subclass_implementers
- def add_variable(self, *args, **kwargs):
- """Alias for `add_weight`."""
- return self.add_weight(*args, **kwargs)
+ def call(self, inputs, **kwargs): # pylint: disable=unused-argument
+ """This is where the layer's logic lives.
+
+ Arguments:
+ inputs: Input tensor, or list/tuple of input tensors.
+ **kwargs: Additional keyword arguments.
+
+ Returns:
+ A tensor or list/tuple of tensors.
+ """
+ return inputs
@doc_controls.for_subclass_implementers
def add_weight(self,
@@ -668,7 +317,7 @@
shape=shape,
# TODO(allenl): a `make_variable` equivalent should be added as a
# `Checkpointable` method.
- getter=getter or make_variable,
+ getter=getter or base_layer_utils.make_variable,
# Manage errors in Layer rather than Checkpointable.
overwrite=True,
initializer=initializer,
@@ -694,341 +343,45 @@
self._non_trainable_weights.append(variable)
return variable
- def _handle_weight_regularization(self, name, variable, regularizer):
- """Create lambdas which compute regularization losses."""
+ def get_config(self):
+ """Returns the config of the layer.
- def _loss_for_variable(v):
- """Creates a regularization loss `Tensor` for variable `v`."""
- with ops.colocate_with(v):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(v)
- return regularization
+ A layer config is a Python dictionary (serializable)
+ containing the configuration of a layer.
+ The same layer can be reinstantiated later
+ (without its trained weights) from this configuration.
- if isinstance(variable, tf_variables.PartitionedVariable):
- for v in variable:
- self.add_loss(functools.partial(_loss_for_variable, v))
- else:
- self.add_loss(functools.partial(_loss_for_variable, variable))
+ The config of a layer does not include connectivity
+ information, nor the layer class name. These are handled
+ by `Network` (one layer of abstraction above).
- def _handle_activity_regularization(self, inputs, outputs):
- # Apply activity regularization.
- # Note that it should be applied every time the layer creates a new
- # output, since it is output-specific.
- if self._activity_regularizer:
- output_list = nest.flatten(outputs)
- with ops.name_scope('ActivityRegularizer'):
- for output in output_list:
- activity_loss = self._activity_regularizer(output)
- batch_size = math_ops.cast(
- array_ops.shape(output)[0], activity_loss.dtype)
- # Make activity regularization strength batch-agnostic.
- mean_activity_loss = activity_loss / batch_size
- self.add_loss(mean_activity_loss, inputs=inputs)
+ Returns:
+ Python dictionary.
+ """
+ config = {'name': self.name, 'trainable': self.trainable}
+ if hasattr(self, '_batch_input_shape'):
+ config['batch_input_shape'] = self._batch_input_shape
+ if hasattr(self, 'dtype'):
+ config['dtype'] = self.dtype
+ return config
- @doc_controls.for_subclass_implementers
- def call(self, inputs, **kwargs): # pylint: disable=unused-argument
- """This is where the layer's logic lives.
+ @classmethod
+ def from_config(cls, config):
+ """Creates a layer from its config.
+
+ This method is the reverse of `get_config`,
+ capable of instantiating the same layer from the config
+ dictionary. It does not handle layer connectivity
+ (handled by Network), nor weights (handled by `set_weights`).
Arguments:
- inputs: Input tensor, or list/tuple of input tensors.
- **kwargs: Additional keyword arguments.
+ config: A Python dictionary, typically the
+ output of get_config.
Returns:
- A tensor or list/tuple of tensors.
+ A layer instance.
"""
- return inputs
-
- def __call__(self, inputs, *args, **kwargs):
- """Wraps `call`, applying pre- and post-processing steps.
-
- Arguments:
- inputs: input tensor(s).
- *args: additional positional arguments to be passed to `self.call`.
- **kwargs: additional keyword arguments to be passed to `self.call`.
-
- Returns:
- Output tensor(s).
-
- Note:
- - The following optional keyword arguments are reserved for specific uses:
- * `training`: Boolean scalar tensor of Python boolean indicating
- whether the `call` is meant for training or inference.
- * `mask`: Boolean input mask.
- - If the layer's `call` method takes a `mask` argument (as some Keras
- layers do), its default value will be set to the mask generated
- for `inputs` by the previous layer (if `input` did come from
- a layer that generated a corresponding mask, i.e. if it came from
- a Keras layer with masking support.
-
- Raises:
- ValueError: if the layer's `call` method returns None (an invalid value).
- """
- input_list = nest.flatten(inputs)
-
- if context.executing_eagerly():
- # Accept NumPy inputs by converting to Tensors when executing eagerly.
- if all([isinstance(x, (np.ndarray, float, int)) for x in input_list]):
- inputs = nest.map_structure(ops.convert_to_tensor, inputs)
- input_list = nest.flatten(inputs)
-
- # We will attempt to build a TF graph if & only if all inputs are symbolic.
- # This is always the case in graph mode. It can also be the case in eager
- # mode when all inputs can be traced back to `keras.Input()` (when building
- # models using the functional API).
- build_graph = tf_utils.are_all_symbolic_tensors(input_list)
- executing_eagerly = context.executing_eagerly()
-
- # Handle Keras mask propagation from previous layer to current layer.
- previous_mask = None
- if build_graph and (not hasattr(self, '_compute_previous_mask') or
- self._compute_previous_mask):
- previous_mask = collect_previous_mask(inputs)
- if not hasattr(self, '_call_fn_args'):
- self._call_fn_args = self._no_dependency(
- function_utils.fn_args(self.call))
- if ('mask' in self._call_fn_args and 'mask' not in kwargs and
- not generic_utils.is_all_none(previous_mask)):
- # The previous layer generated a mask, and mask was not explicitly pass
- # to __call__, hence we set previous_mask as the default value.
- kwargs['mask'] = previous_mask
-
- input_shapes = None
-
- with ops.name_scope(self._name_scope()):
- if not self.built:
- # Check input assumptions set before layer building, e.g. input rank.
- self._assert_input_compatibility(inputs)
- if input_list and self._dtype is None:
- try:
- self._dtype = input_list[0].dtype.base_dtype.name
- except AttributeError:
- pass
-
- if all(hasattr(x, 'shape') for x in input_list):
- input_shapes = nest.map_structure(lambda x: x.shape, inputs)
-
- if (not hasattr(self, '_is_graph_network') or
- self.__class__.__name__ == 'Sequential' or
- not hasattr(self.build, '_is_default')):
- # Only if self is a layer, an instance of a sequential model, or
- # the user has manually overwritten the build method do we need to
- # build it.
- self.build(input_shapes)
- # We must set self.built since user defined build functions are not
- # constrained to set self.built.
- self.built = True
-
- # Check input assumptions set after layer building, e.g. input shape.
- if build_graph:
- # Symbolic execution on symbolic tensors. We will attempt to build
- # the corresponding TF subgraph inside `backend.get_graph()`
- self._assert_input_compatibility(inputs)
- graph = backend.get_graph()
- with graph.as_default():
- if not executing_eagerly:
- # In graph mode, failure to build the layer's graph
- # implies a user-side bug. We don't catch exceptions.
- outputs = self.call(inputs, *args, **kwargs)
- else:
- try:
- outputs = self.call(inputs, *args, **kwargs)
- except Exception: # pylint: disable=broad-except
- # Any issue during graph-building means we will later run the
- # model in eager mode, whether the issue was related to
- # graph mode or not. This provides a nice debugging experience.
- self._call_is_graph_friendly = False
- # We will use static shape inference to return symbolic tensors
- # matching the specifications of the layer outputs.
- # Since we have set `self._call_is_graph_friendly = False`,
- # we will never attempt to run the underlying TF graph (which is
- # disconnected).
- # TODO(fchollet): consider py_func as an alternative, which
- # would enable us to run the underlying graph if needed.
- input_shapes = nest.map_structure(lambda x: x.shape, inputs)
- output_shapes = self.compute_output_shape(input_shapes)
- outputs = nest.map_structure(
- lambda shape: backend.placeholder(shape, dtype=self.dtype),
- output_shapes)
-
- if outputs is None:
- raise ValueError('A layer\'s `call` method should return a '
- 'Tensor or a list of Tensors, not None '
- '(layer: ' + self.name + ').')
- self._handle_activity_regularization(inputs, outputs)
- self._set_mask_metadata(inputs, outputs, previous_mask)
- if have_all_keras_metadata(inputs):
- inputs, outputs = self._set_connectivity_metadata_(
- inputs, outputs, args, kwargs)
- if hasattr(self, '_set_inputs') and not self.inputs:
- # Subclassed network: explicitly set metadata normally set by
- # a call to self._set_inputs().
- # This is not relevant in eager execution.
- self._set_inputs(inputs, outputs)
- else:
- # Eager execution on data tensors.
- outputs = self.call(inputs, *args, **kwargs)
- self._handle_activity_regularization(inputs, outputs)
- return outputs
-
- if not context.executing_eagerly():
- # Optionally load weight values specified at layer instantiation.
- # TODO(fchollet): consider enabling this with eager execution too.
- if (hasattr(self, '_initial_weights') and
- self._initial_weights is not None):
- self.set_weights(self._initial_weights)
- del self._initial_weights
- return outputs
-
- def apply(self, inputs, *args, **kwargs):
- """Apply the layer on a input.
-
- This simply wraps `self.__call__`.
-
- Arguments:
- inputs: Input tensor(s).
- *args: additional positional arguments to be passed to `self.call`.
- **kwargs: additional keyword arguments to be passed to `self.call`.
-
- Returns:
- Output tensor(s).
- """
- return self.__call__(inputs, *args, **kwargs)
-
- def _set_mask_metadata(self, inputs, outputs, previous_mask):
- # In some cases the mask of the outputs has already been computed by
- # inner layers and does not need to be recomputed by this layer.
- mask_already_computed = all(
- hasattr(x, '_keras_mask') for x in generic_utils.to_list(outputs))
- if hasattr(self, 'compute_mask') and not mask_already_computed:
- output_mask = self.compute_mask(inputs, previous_mask)
- else:
- output_mask = None
- if isinstance(outputs, (list, tuple)):
- if output_mask is None:
- output_mask = [None for _ in range(len(outputs))]
- for x, m in zip(outputs, output_mask):
- try:
- x._keras_mask = m # pylint: disable=protected-access
- except AttributeError:
- pass # C type such as dict. Masking not supported in this case.
- else:
- try:
- outputs._keras_mask = output_mask # pylint: disable=protected-access
- except AttributeError:
- pass # C type such as dict. Masking not supported in this case.
-
- def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):
- call_convention = getattr(self, '_call_convention',
- CallConvention.EXPLICIT_INPUTS_ARGUMENT)
- if args:
- if call_convention == CallConvention.EXPLICIT_INPUTS_ARGUMENT:
- raise TypeError(
- 'This layer ("{}") takes an `inputs` argument in `call()`, '
- 'and only the `inputs` argument may be specified as a positional '
- 'argument. Pass everything else as a keyword argument '
- '(those arguments will not be tracked '
- 'as inputs to the layer).'.format(self.name))
- elif call_convention == CallConvention.SINGLE_POSITIONAL_ARGUMENT:
- raise TypeError(
- 'This layer ("{}") takes a single positional argument in `call()`,'
- ' which is by convention the `inputs` argument, '
- 'and only this argument may be specified as a positional argument. '
- 'Pass everything else as a keyword argument '
- '(those arguments will not be tracked '
- 'as inputs to the layer).'.format(self.name))
-
- # If the layer returns tensors from its inputs, unmodified,
- # we copy them to avoid loss of tensor metadata.
- output_ls = nest.flatten(outputs)
- output_ls_copy = []
- for x in output_ls:
- if x in nest.flatten(inputs):
- with ops.name_scope(self.name):
- x = array_ops.identity(x)
- output_ls_copy.append(x)
- if len(output_ls_copy) == 1:
- outputs = output_ls_copy[0]
- else:
- outputs = output_ls_copy
-
- inputs, kwargs = self._inputs_from_call_args(
- call_args=(inputs,) + args, call_kwargs=kwargs)
- # Add an inbound node to the layer, so it can keep track of this call.
- # This updates the layer history of the output tensor(s).
- kwargs.pop('mask', None) # `mask` should not be serialized.
- self._add_inbound_node(
- input_tensors=inputs, output_tensors=outputs, arguments=kwargs)
- return inputs, outputs
-
- def _inputs_from_call_args(self, call_args, call_kwargs):
- """Get Layer inputs from __call__ *args and **kwargs.
-
- Args:
- call_args: The positional arguments passed to __call__.
- call_kwargs: The keyword argument dict passed to __call__.
-
- Returns:
- A tuple of (inputs, non_input_kwargs). These may be the same objects as
- were passed in (call_args and call_kwargs).
- """
- call_convention = getattr(self, '_call_convention',
- CallConvention.EXPLICIT_INPUTS_ARGUMENT)
- if (call_convention in (
- CallConvention.EXPLICIT_INPUTS_ARGUMENT,
- CallConvention.SINGLE_POSITIONAL_ARGUMENT)):
- assert len(call_args) == 1 # TypeError raised earlier in __call__.
- return call_args[0], call_kwargs
- else:
- call_arg_spec = tf_inspect.getfullargspec(self.call)
- # There is no explicit "inputs" argument expected or provided to
- # call(). Arguments which have default values are considered non-inputs,
- # and arguments without are considered inputs.
- if call_arg_spec.defaults:
- if call_arg_spec.varargs is not None:
- raise TypeError(
- 'Layers may not accept both positional arguments and '
- 'arguments with default values (unable to determine which '
- 'are inputs to the layer). '
- 'Issue occurred with layer "%s"' % (self.name))
- keyword_arg_names = set(
- call_arg_spec.args[-len(call_arg_spec.defaults):])
- else:
- keyword_arg_names = set()
- # Training is never an input argument name, to allow signatures like
- # call(x, training).
- keyword_arg_names.add('training')
- _, unwrapped_call = tf_decorator.unwrap(self.call)
- bound_args = inspect.getcallargs(
- unwrapped_call, *call_args, **call_kwargs)
- if call_arg_spec.varkw is not None:
- var_kwargs = bound_args.pop(call_arg_spec.varkw)
- bound_args.update(var_kwargs)
- keyword_arg_names = keyword_arg_names.union(var_kwargs.keys())
- all_args = call_arg_spec.args
- if all_args and bound_args[all_args[0]] is self:
- # Ignore the 'self' argument of methods
- bound_args.pop(call_arg_spec.args[0])
- all_args = all_args[1:]
- non_input_arg_values = {}
- input_arg_values = []
- remaining_args_are_keyword = False
- for argument_name in all_args:
- if argument_name in keyword_arg_names:
- remaining_args_are_keyword = True
- else:
- if remaining_args_are_keyword:
- raise TypeError(
- 'Found a positional argument in a layer call after a non-input '
- 'argument. All arguments after "training" must be keyword '
- 'arguments, and are not tracked as inputs to the layer. '
- 'Issue occurred with layer "%s"' % (self.name))
- if remaining_args_are_keyword:
- non_input_arg_values[argument_name] = bound_args[argument_name]
- else:
- input_arg_values.append(bound_args[argument_name])
- if call_arg_spec.varargs is not None:
- input_arg_values.extend(bound_args[call_arg_spec.varargs])
- return input_arg_values, non_input_arg_values
+ return cls(**config)
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer.
@@ -1057,10 +410,11 @@
graph = func_graph.FuncGraph('graph')
with graph.as_default():
if isinstance(input_shape, list):
- inputs = [generate_placeholders_from_shape(shape)
+ inputs = [base_layer_utils.generate_placeholders_from_shape(shape)
for shape in input_shape]
else:
- inputs = generate_placeholders_from_shape(input_shape)
+ inputs = base_layer_utils.generate_placeholders_from_shape(
+ input_shape)
try:
if self._expects_training_arg:
@@ -1105,86 +459,442 @@
# carry over the input mask
return mask
- def _add_inbound_node(self,
- input_tensors,
- output_tensors,
- arguments=None):
- """Internal method to create an inbound node for the layer.
+ def __call__(self, inputs, *args, **kwargs):
+ """Wraps `call`, applying pre- and post-processing steps.
Arguments:
- input_tensors: list of input tensors.
- output_tensors: list of output tensors.
- arguments: dictionary of keyword arguments that were passed to the
- `call` method of the layer at the call that created the node.
- """
- input_tensors = nest.flatten(input_tensors)
- output_tensors = nest.flatten(output_tensors)
-
- # Collect input tensor(s) coordinates.
- inbound_layers = []
- node_indices = []
- tensor_indices = []
- for x in input_tensors:
- assert hasattr(x, '_keras_history')
- inbound_layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access
- inbound_layers.append(inbound_layer)
- node_indices.append(node_index)
- tensor_indices.append(tensor_index)
-
- # Create node, add it to inbound nodes.
- Node(
- self,
- inbound_layers=inbound_layers,
- node_indices=node_indices,
- tensor_indices=tensor_indices,
- input_tensors=input_tensors,
- output_tensors=output_tensors,
- arguments=arguments)
-
- # Update tensor history metadata.
- for i in range(len(output_tensors)):
- # The metadata attribute consists of 1) a layer instance
- # 2) a node index for the layer, 3) a tensor index for the node.
- # The allows layer reuse (multiple nodes per layer) and multi-output
- # or multi-input layers (e.g. a layer can return multiple tensors,
- # and each can be sent to a different layer).
- output_tensors[i]._keras_history = (self, len(self._inbound_nodes) - 1, i) # pylint: disable=protected-access
-
- def _get_node_attribute_at_index(self, node_index, attr, attr_name):
- """Private utility to retrieves an attribute (e.g. inputs) from a node.
-
- This is used to implement the methods:
- - get_input_shape_at
- - get_output_shape_at
- - get_input_at
- etc...
-
- Arguments:
- node_index: Integer index of the node from which
- to retrieve the attribute.
- attr: Exact node attribute name.
- attr_name: Human-readable attribute name, for error messages.
+ inputs: input tensor(s).
+ *args: additional positional arguments to be passed to `self.call`.
+ **kwargs: additional keyword arguments to be passed to `self.call`.
Returns:
- The layer's attribute `attr` at the node of index `node_index`.
+ Output tensor(s).
+
+ Note:
+ - The following optional keyword arguments are reserved for specific uses:
+ * `training`: Boolean scalar tensor of Python boolean indicating
+ whether the `call` is meant for training or inference.
+ * `mask`: Boolean input mask.
+ - If the layer's `call` method takes a `mask` argument (as some Keras
+ layers do), its default value will be set to the mask generated
+ for `inputs` by the previous layer (if `input` did come from
+ a layer that generated a corresponding mask, i.e. if it came from
+ a Keras layer with masking support.
Raises:
- RuntimeError: If the layer has no inbound nodes, or if called in Eager
- mode.
- ValueError: If the index provided does not match any node.
+ ValueError: if the layer's `call` method returns None (an invalid value).
"""
- if not self._inbound_nodes:
- raise RuntimeError('The layer has never been called '
- 'and thus has no defined ' + attr_name + '.')
- if not len(self._inbound_nodes) > node_index:
- raise ValueError('Asked to get ' + attr_name + ' at node ' +
- str(node_index) + ', but the layer has only ' +
- str(len(self._inbound_nodes)) + ' inbound nodes.')
- values = getattr(self._inbound_nodes[node_index], attr)
- if len(values) == 1:
- return values[0]
+ input_list = nest.flatten(inputs)
+
+ if context.executing_eagerly():
+ # Accept NumPy inputs by converting to Tensors when executing eagerly.
+ if all(isinstance(x, (np.ndarray, float, int)) for x in input_list):
+ inputs = nest.map_structure(ops.convert_to_tensor, inputs)
+ input_list = nest.flatten(inputs)
+
+ # We will attempt to build a TF graph if & only if all inputs are symbolic.
+ # This is always the case in graph mode. It can also be the case in eager
+ # mode when all inputs can be traced back to `keras.Input()` (when building
+ # models using the functional API).
+ build_graph = tf_utils.are_all_symbolic_tensors(input_list)
+ executing_eagerly = context.executing_eagerly()
+
+ # Handle Keras mask propagation from previous layer to current layer.
+ previous_mask = None
+ if build_graph and (not hasattr(self, '_compute_previous_mask') or
+ self._compute_previous_mask):
+ previous_mask = base_layer_utils.collect_previous_mask(inputs)
+ if not hasattr(self, '_call_fn_args'):
+ self._call_fn_args = self._no_dependency(
+ function_utils.fn_args(self.call))
+ if ('mask' in self._call_fn_args and 'mask' not in kwargs and
+ not generic_utils.is_all_none(previous_mask)):
+ # The previous layer generated a mask, and mask was not explicitly pass
+ # to __call__, hence we set previous_mask as the default value.
+ kwargs['mask'] = previous_mask
+
+ input_shapes = None
+
+ with ops.name_scope(self._name_scope()):
+ if not self.built:
+ # Check input assumptions set before layer building, e.g. input rank.
+ input_spec.assert_input_compatibility(
+ self.input_spec, inputs, self.name)
+ if input_list and self._dtype is None:
+ try:
+ self._dtype = input_list[0].dtype.base_dtype.name
+ except AttributeError:
+ pass
+
+ if all(hasattr(x, 'shape') for x in input_list):
+ input_shapes = nest.map_structure(lambda x: x.shape, inputs)
+
+ if (not hasattr(self, '_is_graph_network') or
+ self.__class__.__name__ == 'Sequential' or
+ not hasattr(self.build, '_is_default')):
+ # Only if self is a layer, an instance of a sequential model, or
+ # the user has manually overwritten the build method do we need to
+ # build it.
+ self.build(input_shapes)
+ # We must set self.built since user defined build functions are not
+ # constrained to set self.built.
+ self.built = True
+
+ # Check input assumptions set after layer building, e.g. input shape.
+ if build_graph:
+ # Symbolic execution on symbolic tensors. We will attempt to build
+ # the corresponding TF subgraph inside `backend.get_graph()`
+ input_spec.assert_input_compatibility(
+ self.input_spec, inputs, self.name)
+ graph = backend.get_graph()
+ with graph.as_default():
+ if not executing_eagerly:
+ # In graph mode, failure to build the layer's graph
+ # implies a user-side bug. We don't catch exceptions.
+ outputs = self.call(inputs, *args, **kwargs)
+ else:
+ try:
+ outputs = self.call(inputs, *args, **kwargs)
+ except Exception: # pylint: disable=broad-except
+ # Any issue during graph-building means we will later run the
+ # model in eager mode, whether the issue was related to
+ # graph mode or not. This provides a nice debugging experience.
+ self._call_is_graph_friendly = False
+ # We will use static shape inference to return symbolic tensors
+ # matching the specifications of the layer outputs.
+ # Since we have set `self._call_is_graph_friendly = False`,
+ # we will never attempt to run the underlying TF graph (which is
+ # disconnected).
+ # TODO(fchollet): consider py_func as an alternative, which
+ # would enable us to run the underlying graph if needed.
+ input_shapes = nest.map_structure(lambda x: x.shape, inputs)
+ output_shapes = self.compute_output_shape(input_shapes)
+ outputs = nest.map_structure(
+ lambda shape: backend.placeholder(shape, dtype=self.dtype),
+ output_shapes)
+
+ if outputs is None:
+ raise ValueError('A layer\'s `call` method should return a '
+ 'Tensor or a list of Tensors, not None '
+ '(layer: ' + self.name + ').')
+ self._handle_activity_regularization(inputs, outputs)
+ self._set_mask_metadata(inputs, outputs, previous_mask)
+ if base_layer_utils.have_all_keras_metadata(inputs):
+ inputs, outputs = self._set_connectivity_metadata_(
+ inputs, outputs, args, kwargs)
+ if hasattr(self, '_set_inputs') and not self.inputs:
+ # Subclassed network: explicitly set metadata normally set by
+ # a call to self._set_inputs().
+ # This is not relevant in eager execution.
+ self._set_inputs(inputs, outputs)
+ else:
+ # Eager execution on data tensors.
+ outputs = self.call(inputs, *args, **kwargs)
+ self._handle_activity_regularization(inputs, outputs)
+ return outputs
+
+ if not context.executing_eagerly():
+ # Optionally load weight values specified at layer instantiation.
+ # TODO(fchollet): consider enabling this with eager execution too.
+ if (hasattr(self, '_initial_weights') and
+ self._initial_weights is not None):
+ self.set_weights(self._initial_weights)
+ del self._initial_weights
+ return outputs
+
+ @property
+ def dtype(self):
+ return self._dtype
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def activity_regularizer(self):
+ """Optional regularizer function for the output of this layer."""
+ return self._activity_regularizer
+
+ @activity_regularizer.setter
+ def activity_regularizer(self, regularizer):
+ """Optional regularizer function for the output of this layer."""
+ self._activity_regularizer = self._no_dependency(regularizer)
+
+ @property
+ def trainable_weights(self):
+ return self._trainable_weights if self.trainable else []
+
+ @property
+ def non_trainable_weights(self):
+ if self.trainable:
+ return self._non_trainable_weights
else:
- return values
+ return self._trainable_weights + self._non_trainable_weights
+
+ @property
+ def weights(self):
+ """Returns the list of all layer variables/weights.
+
+ Returns:
+ A list of variables.
+ """
+ return self.trainable_weights + self.non_trainable_weights
+
+ @property
+ def updates(self):
+ if not self.trainable and not self.stateful:
+ return []
+ return self._updates
+
+ @property
+ def losses(self):
+ """Losses which are associated with this `Layer`.
+
+ Variable regularization tensors are created when this property is accessed,
+ so it is eager safe: accessing `losses` under a `tf.GradientTape` will
+ propagate gradients back to the corresponding variables.
+
+ Returns:
+ A list of tensors.
+ """
+ collected_losses = []
+ if context.executing_eagerly():
+ collected_losses.extend(self._eager_losses)
+ else:
+ collected_losses.extend(self._losses)
+ for regularizer in self._callable_losses:
+ loss_tensor = regularizer()
+ if loss_tensor is not None:
+ collected_losses.append(loss_tensor)
+ return collected_losses
+
+ @doc_controls.for_subclass_implementers
+ def add_loss(self, losses, inputs=None):
+ """Add loss tensor(s), potentially dependent on layer inputs.
+
+ Some losses (for instance, activity regularization losses) may be dependent
+ on the inputs passed when calling a layer. Hence, when reusing the same
+ layer on different inputs `a` and `b`, some entries in `layer.losses` may
+ be dependent on `a` and some on `b`. This method automatically keeps track
+ of dependencies.
+
+ The `get_losses_for` method allows to retrieve the losses relevant to a
+ specific set of inputs.
+
+ Note that `add_loss` is not supported when executing eagerly. Instead,
+ variable regularizers may be added through `add_variable`. Activity
+ regularization is not supported directly (but such losses may be returned
+ from `Layer.call()`).
+
+ Arguments:
+ losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
+ may also be zero-argument callables which create a loss tensor.
+ inputs: Ignored when executing eagerly. If anything other than None is
+ passed, it signals the losses are conditional on some of the layer's
+ inputs, and thus they should only be run where these inputs are
+ available. This is the case for activity regularization losses, for
+ instance. If `None` is passed, the losses are assumed
+ to be unconditional, and will apply across all dataflows of the layer
+ (e.g. weight regularization losses).
+ """
+ losses = generic_utils.to_list(losses)
+
+ def _tag_unconditional(loss):
+ if callable(loss):
+ loss = loss()
+ if loss is None:
+ return None # Will be filtered out when computing the .losses property
+ if not tensor_util.is_tensor(loss):
+ loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
+ loss._unconditional_loss = (inputs is None) # pylint: disable=protected-access
+ return loss
+
+ for loss in losses:
+ if callable(loss):
+ self._callable_losses.append(
+ functools.partial(_tag_unconditional, loss))
+ else:
+ if context.executing_eagerly():
+ self._eager_losses.append(_tag_unconditional(loss))
+ else:
+ self._losses.append(_tag_unconditional(loss))
+
+ @doc_controls.for_subclass_implementers
+ def add_metric(self, value, aggregation=None, name=None):
+ """Adds metric tensor to the layer.
+
+ Args:
+ value: Metric tensor.
+ aggregation: Sample-wise metric reduction function. If `aggregation=None`,
+ it indicates that the metric tensor provided has been aggregated
+ already. eg, `model.add_metric(BinaryAccuracy(name='acc')(y_true,
+ y_pred))`. If aggregation='mean', the given metric tensor will be
+ sample-wise reduced using `mean` function. eg, `model.add_metric(
+ tf.reduce_mean(outputs), name='output_mean', aggregation='mean')`.
+ name: String metric name.
+
+ Raises:
+ ValueError: If `aggregation` is anything other than None or `mean`.
+ """
+ if aggregation is not None and aggregation != 'mean':
+ raise ValueError(
+ 'We currently support only `mean` sample-wise metric aggregation. '
+ 'You provided aggregation=`%s`' % aggregation)
+
+ if tf_utils.is_symbolic_tensor(value):
+ self._symbolic_add_metric(value, aggregation, name)
+ else:
+ self._eager_add_metric(value, aggregation, name)
+
+ @doc_controls.for_subclass_implementers
+ def add_update(self, updates, inputs=None):
+ """Add update op(s), potentially dependent on layer inputs.
+
+ Weight updates (for instance, the updates of the moving mean and variance
+ in a BatchNormalization layer) may be dependent on the inputs passed
+ when calling a layer. Hence, when reusing the same layer on
+ different inputs `a` and `b`, some entries in `layer.updates` may be
+ dependent on `a` and some on `b`. This method automatically keeps track
+ of dependencies.
+
+ The `get_updates_for` method allows to retrieve the updates relevant to a
+ specific set of inputs.
+
+ This call is ignored when eager execution is enabled (in that case, variable
+ updates are run on the fly and thus do not need to be tracked for later
+ execution).
+
+ Arguments:
+ updates: Update op, or list/tuple of update ops.
+ inputs: If anything other than None is passed, it signals the updates
+ are conditional on some of the layer's inputs,
+ and thus they should only be run where these inputs are available.
+ This is the case for BatchNormalization updates, for instance.
+ If None, the updates will be taken into account unconditionally,
+ and you are responsible for making sure that any dependency they might
+ have is available at runtime.
+ A step counter might fall into this category.
+ """
+ if context.executing_eagerly():
+ return # Updates already applied when in eager mode.
+
+ def process_update(x):
+ if isinstance(x, ops.Operation):
+ return x
+ elif hasattr(x, 'op'):
+ return x.op
+ else:
+ return ops.convert_to_tensor(x)
+
+ updates = generic_utils.to_list(updates)
+ updates = [process_update(x) for x in updates]
+ self._updates += updates
+ if inputs is None:
+ for u in updates:
+ u._unconditional_update = True # pylint: disable=protected-access
+ else:
+ for u in updates:
+ u._unconditional_update = False # pylint: disable=protected-access
+
+ def set_weights(self, weights):
+ """Sets the weights of the layer, from Numpy arrays.
+
+ Arguments:
+ weights: a list of Numpy arrays. The number
+ of arrays and their shape must match
+ number of the dimensions of the weights
+ of the layer (i.e. it should match the
+ output of `get_weights`).
+
+ Raises:
+ ValueError: If the provided weights list does not match the
+ layer's specifications.
+ """
+ params = self.weights
+ if len(params) != len(weights):
+ raise ValueError('You called `set_weights(weights)` on layer "' +
+ self.name + '" with a weight list of length ' +
+ str(len(weights)) + ', but the layer was expecting ' +
+ str(len(params)) + ' weights. Provided weights: ' +
+ str(weights)[:50] + '...')
+ if not params:
+ return
+ weight_value_tuples = []
+ param_values = backend.batch_get_value(params)
+ for pv, p, w in zip(param_values, params, weights):
+ if pv.shape != w.shape:
+ raise ValueError('Layer weight shape ' + str(pv.shape) +
+ ' not compatible with '
+ 'provided weight shape ' + str(w.shape))
+ weight_value_tuples.append((p, w))
+ backend.batch_set_value(weight_value_tuples)
+
+ def get_weights(self):
+ """Returns the current weights of the layer.
+
+ Returns:
+ Weights values as a list of numpy arrays.
+ """
+ params = self.weights
+ return backend.batch_get_value(params)
+
+ def get_updates_for(self, inputs):
+ """Retrieves updates relevant to a specific set of inputs.
+
+ Arguments:
+ inputs: Input tensor or list/tuple of input tensors.
+
+ Returns:
+ List of update ops of the layer that depend on `inputs`.
+
+ Raises:
+ RuntimeError: If called in Eager mode.
+ """
+ # Updates disabled if layer is not trainable and not explicitly stateful.
+ if not self.trainable and not self.stateful:
+ return []
+
+ if inputs is None:
+ # Requesting unconditional updates.
+ return [x for x in self.updates if x._unconditional_update] # pylint: disable=protected-access
+
+ # Requesting input-conditional updates.
+ inputs = nest.flatten(inputs)
+ reachable = tf_utils.get_reachable_from_inputs(inputs, self.updates)
+ updates = []
+ for update in self.updates:
+ if update in reachable:
+ updates.append(update)
+ return updates
+
+ def get_losses_for(self, inputs):
+ """Retrieves losses relevant to a specific set of inputs.
+
+ Arguments:
+ inputs: Input tensor or list/tuple of input tensors.
+
+ Returns:
+ List of loss tensors of the layer that depend on `inputs`.
+
+ Raises:
+ RuntimeError: If called in Eager mode.
+ """
+ if inputs is None:
+ # Requesting unconditional losses.
+ return [x for x in self.losses if x._unconditional_loss] # pylint: disable=protected-access
+
+ # Requesting input-conditional losses.
+ inputs = nest.flatten(inputs)
+ # Retrieve the set of tensors in the TF graph that depend on `inputs`.
+ # The losses we want to return will be part of this set.
+ # To avoid unnecessary work, we stop the search in case all of
+ # `self.losses` have been retrieved.
+ reachable = tf_utils.get_reachable_from_inputs(inputs, self.losses)
+ losses = []
+ for loss in self.losses:
+ if loss in reachable:
+ losses.append(loss)
+ return losses
def get_input_mask_at(self, node_index):
"""Retrieves the input mask tensor(s) of a layer at a given node.
@@ -1439,8 +1149,7 @@
', but the layer isn\'t built. '
'You can build it manually via: `' + self.name +
'.build(batch_input_shape)`.')
- weight_shapes = [w.shape.as_list() for w in self.weights]
- return int(sum([np.prod(w) for w in weight_shapes]))
+ return int(sum(np.prod(w.shape.as_list()) for w in self.weights))
@property
def output_shape(self):
@@ -1492,182 +1201,367 @@
"""Deprecated, do NOT use! Only for compatibility with external Keras."""
return self._outbound_nodes
- def _assert_input_compatibility(self, inputs):
- """Checks compatibility between the layer and provided inputs.
+ ##############################################################################
+ # Methods & attributes below are public aliases of other methods. #
+ ##############################################################################
- This checks that the tensor(s) `inputs` verify the input assumptions
- of the layer (if any). If not, a clear and actional exception gets raised.
+ def apply(self, inputs, *args, **kwargs):
+ """Apply the layer on a input.
+
+ This is an alias of `self.__call__`.
Arguments:
- inputs: input tensor or list of input tensors.
+ inputs: Input tensor(s).
+ *args: additional positional arguments to be passed to `self.call`.
+ **kwargs: additional keyword arguments to be passed to `self.call`.
- Raises:
- ValueError: in case of mismatch between
- the provided inputs and the expectations of the layer.
+ Returns:
+ Output tensor(s).
"""
- if not self.input_spec:
- return
- if not isinstance(self.input_spec, (list, tuple)):
- input_spec = nest.flatten(self.input_spec)
+ return self.__call__(inputs, *args, **kwargs)
+
+ @doc_controls.for_subclass_implementers
+ def add_variable(self, *args, **kwargs):
+ """Alias for `add_weight`."""
+ return self.add_weight(*args, **kwargs)
+
+ @property
+ def variables(self):
+ """Returns the list of all layer variables/weights.
+
+ Alias of `self.weights`.
+
+ Returns:
+ A list of variables.
+ """
+ return self.weights
+
+ @property
+ def trainable_variables(self):
+ return self.trainable_weights
+
+ @property
+ def non_trainable_variables(self):
+ return self.non_trainable_weights
+
+ ##############################################################################
+ # Methods & attributes below are all private and only used by the framework. #
+ ##############################################################################
+
+ def _name_scope(self):
+ return self.name
+
+ def _init_set_name(self, name, zero_based=True):
+ if not name:
+ self._name = base_layer_utils.unique_layer_name(
+ generic_utils.to_snake_case(self.__class__.__name__),
+ zero_based=zero_based)
else:
- input_spec = self.input_spec
- inputs = nest.flatten(inputs)
- if len(inputs) != len(input_spec):
- raise ValueError('Layer ' + self.name + ' expects ' +
- str(len(input_spec)) + ' inputs, '
- 'but it received ' + str(len(inputs)) +
- ' input tensors. Inputs received: ' + str(inputs))
- for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
- if spec is None:
- continue
+ self._name = name
- if (spec.ndim is not None or
- spec.min_ndim is not None or
- spec.max_ndim is not None):
- if x.shape.ndims is None:
- raise ValueError('Input ' + str(input_index) + ' of layer ' +
- self.name + ' is incompatible with the layer: '
- 'its rank is undefined, but the layer requires a '
- 'defined rank.')
+ def _get_existing_metric(self, name=None):
+ match = [m for m in self._metrics if m.name == name]
+ if not match:
+ return
+ if len(match) > 1:
+ raise ValueError(
+ 'Please provide different names for the metrics you have added. '
+ 'We found {} metrics with the name: "{}"'.format(len(match), name))
+ return match[0]
- # Check ndim.
- if spec.ndim is not None:
- ndim = x.shape.ndims
- if ndim != spec.ndim:
- raise ValueError('Input ' + str(input_index) + ' of layer ' +
- self.name + ' is incompatible with the layer: '
- 'expected ndim=' + str(spec.ndim) + ', found ndim=' +
- str(ndim) + '. Full shape received: ' +
- str(x.shape.as_list()))
- if spec.max_ndim is not None:
- ndim = x.shape.ndims
- if ndim is not None and ndim > spec.max_ndim:
- raise ValueError('Input ' + str(input_index) + ' of layer ' +
- self.name + ' is incompatible with the layer: '
- 'expected max_ndim=' + str(spec.max_ndim) +
- ', found ndim=' + str(ndim))
- if spec.min_ndim is not None:
- ndim = x.shape.ndims
- if ndim is not None and ndim < spec.min_ndim:
- raise ValueError('Input ' + str(input_index) + ' of layer ' +
- self.name + ' is incompatible with the layer: '
- ': expected min_ndim=' + str(spec.min_ndim) +
- ', found ndim=' + str(ndim) +
- '. Full shape received: ' +
- str(x.shape.as_list()))
- # Check dtype.
- if spec.dtype is not None:
- if x.dtype != spec.dtype:
- raise ValueError('Input ' + str(input_index) + ' of layer ' +
- self.name + ' is incompatible with the layer: '
- 'expected dtype=' + str(spec.dtype) +
- ', found dtype=' + str(x.dtype))
- # Check specific shape axes.
- if spec.axes:
- shape = x.shape.as_list()
- if shape is not None:
- for axis, value in spec.axes.items():
- if hasattr(value, 'value'):
- value = value.value
- if value is not None and shape[int(axis)] not in {value, None}:
- raise ValueError(
- 'Input ' + str(input_index) + ' of layer ' + self.name + ' is'
- ' incompatible with the layer: expected axis ' + str(axis) +
- ' of input shape to have value ' + str(value) +
- ' but received input with shape ' + str(shape))
- # Check shape.
- if spec.shape is not None:
- shape = x.shape.as_list()
- if shape is not None:
- for spec_dim, dim in zip(spec.shape, shape):
- if spec_dim is not None and dim is not None:
- if spec_dim != dim:
- raise ValueError('Input ' + str(input_index) +
- ' is incompatible with layer ' + self.name +
- ': expected shape=' + str(spec.shape) +
- ', found shape=' + str(shape))
+ def _eager_add_metric(self, value, aggregation=None, name=None):
+ # If the given metric is available in `metrics` list we just update state
+ # on it, otherwise we create a new metric instance and
+ # add it to the `metrics` list.
+ match = self._get_existing_metric(name)
+ if match:
+ match(value) # Update the metric state.
+ return
+ else:
+ if aggregation is None:
+ raise ValueError('We do not support adding an aggregated metric tensor '
+ 'in `call` in eager execution.')
+ metric_obj, _ = base_layer_utils.create_mean_metric(value, name)
+ self._metrics.append(metric_obj)
- def set_weights(self, weights):
- """Sets the weights of the layer, from Numpy arrays.
+ def _symbolic_add_metric(self, value, aggregation=None, name=None):
+ if aggregation is None:
+ # Iterate over the metrics and check if the given metric exists already.
+ # This can happen when a metric instance is created in subclassed model
+ # layer `__init__` and we have tracked that instance already in
+ # model.__setattr__.
+ match = self._get_existing_metric(name)
+ if match:
+ result_tensor = value
+ if match.name not in self._metrics_tensors:
+ self._metrics_tensors[match.name] = result_tensor
+ return
+ else:
+ raise ValueError(
+ 'We currently do not support reusing a metric instance.')
+ else:
+ # We track the instance using the metadata on the result tensor.
+ result_tensor = value
+ metric_obj = result_tensor._metric_obj
+ else:
+ # If a non-aggregated tensor is given as input (ie. `aggregation` is
+ # explicitly set to `mean`), we wrap the tensor in `Mean` metric.
+ metric_obj, result_tensor = base_layer_utils.create_mean_metric(
+ value, name)
+ self._metrics.append(metric_obj)
+ self._metrics_tensors[metric_obj.name] = result_tensor
+
+ def _handle_weight_regularization(self, name, variable, regularizer):
+ """Create lambdas which compute regularization losses."""
+
+ def _loss_for_variable(v):
+ """Creates a regularization loss `Tensor` for variable `v`."""
+ with ops.colocate_with(v):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(v)
+ return regularization
+
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ for v in variable:
+ self.add_loss(functools.partial(_loss_for_variable, v))
+ else:
+ self.add_loss(functools.partial(_loss_for_variable, variable))
+
+ def _handle_activity_regularization(self, inputs, outputs):
+ # Apply activity regularization.
+ # Note that it should be applied every time the layer creates a new
+ # output, since it is output-specific.
+ if self._activity_regularizer:
+ output_list = nest.flatten(outputs)
+ with ops.name_scope('ActivityRegularizer'):
+ for output in output_list:
+ activity_loss = self._activity_regularizer(output)
+ batch_size = math_ops.cast(
+ array_ops.shape(output)[0], activity_loss.dtype)
+ # Make activity regularization strength batch-agnostic.
+ mean_activity_loss = activity_loss / batch_size
+ self.add_loss(mean_activity_loss, inputs=inputs)
+
+ def _set_mask_metadata(self, inputs, outputs, previous_mask):
+ # In some cases the mask of the outputs has already been computed by
+ # inner layers and does not need to be recomputed by this layer.
+ mask_already_computed = all(
+ hasattr(x, '_keras_mask') for x in generic_utils.to_list(outputs))
+ if hasattr(self, 'compute_mask') and not mask_already_computed:
+ output_mask = self.compute_mask(inputs, previous_mask)
+ else:
+ output_mask = None
+ if isinstance(outputs, (list, tuple)):
+ if output_mask is None:
+ output_mask = [None for _ in range(len(outputs))]
+ for x, m in zip(outputs, output_mask):
+ try:
+ x._keras_mask = m # pylint: disable=protected-access
+ except AttributeError:
+ pass # C type such as dict. Masking not supported in this case.
+ else:
+ try:
+ outputs._keras_mask = output_mask # pylint: disable=protected-access
+ except AttributeError:
+ pass # C type such as dict. Masking not supported in this case.
+
+ def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):
+ call_convention = getattr(
+ self, '_call_convention',
+ base_layer_utils.CallConvention.EXPLICIT_INPUTS_ARGUMENT)
+ if args:
+ if call_convention == (base_layer_utils
+ .CallConvention.EXPLICIT_INPUTS_ARGUMENT):
+ raise TypeError(
+ 'This layer ("{}") takes an `inputs` argument in `call()`, '
+ 'and only the `inputs` argument may be specified as a positional '
+ 'argument. Pass everything else as a keyword argument '
+ '(those arguments will not be tracked '
+ 'as inputs to the layer).'.format(self.name))
+ elif call_convention == (base_layer_utils
+ .CallConvention.SINGLE_POSITIONAL_ARGUMENT):
+ raise TypeError(
+ 'This layer ("{}") takes a single positional argument in `call()`,'
+ ' which is by convention the `inputs` argument, '
+ 'and only this argument may be specified as a positional argument. '
+ 'Pass everything else as a keyword argument '
+ '(those arguments will not be tracked '
+ 'as inputs to the layer).'.format(self.name))
+
+ # If the layer returns tensors from its inputs, unmodified,
+ # we copy them to avoid loss of tensor metadata.
+ output_ls = nest.flatten(outputs)
+ output_ls_copy = []
+ for x in output_ls:
+ if x in nest.flatten(inputs):
+ with ops.name_scope(self.name):
+ x = array_ops.identity(x)
+ output_ls_copy.append(x)
+ if len(output_ls_copy) == 1:
+ outputs = output_ls_copy[0]
+ else:
+ outputs = output_ls_copy
+
+ inputs, kwargs = self._inputs_from_call_args(
+ call_args=(inputs,) + args, call_kwargs=kwargs)
+ # Add an inbound node to the layer, so it can keep track of this call.
+ # This updates the layer history of the output tensor(s).
+ kwargs.pop('mask', None) # `mask` should not be serialized.
+ self._add_inbound_node(
+ input_tensors=inputs, output_tensors=outputs, arguments=kwargs)
+ return inputs, outputs
+
+ def _inputs_from_call_args(self, call_args, call_kwargs):
+ """Get Layer inputs from __call__ *args and **kwargs.
+
+ Args:
+ call_args: The positional arguments passed to __call__.
+ call_kwargs: The keyword argument dict passed to __call__.
+
+ Returns:
+ A tuple of (inputs, non_input_kwargs). These may be the same objects as
+ were passed in (call_args and call_kwargs).
+ """
+ call_convention = getattr(
+ self, '_call_convention',
+ base_layer_utils.CallConvention.EXPLICIT_INPUTS_ARGUMENT)
+ if (call_convention in (
+ base_layer_utils.CallConvention.EXPLICIT_INPUTS_ARGUMENT,
+ base_layer_utils.CallConvention.SINGLE_POSITIONAL_ARGUMENT)):
+ assert len(call_args) == 1 # TypeError raised earlier in __call__.
+ return call_args[0], call_kwargs
+ else:
+ call_arg_spec = tf_inspect.getfullargspec(self.call)
+ # There is no explicit "inputs" argument expected or provided to
+ # call(). Arguments which have default values are considered non-inputs,
+ # and arguments without are considered inputs.
+ if call_arg_spec.defaults:
+ if call_arg_spec.varargs is not None:
+ raise TypeError(
+ 'Layers may not accept both positional arguments and '
+ 'arguments with default values (unable to determine which '
+ 'are inputs to the layer). '
+ 'Issue occurred with layer "%s"' % (self.name))
+ keyword_arg_names = set(
+ call_arg_spec.args[-len(call_arg_spec.defaults):])
+ else:
+ keyword_arg_names = set()
+ # Training is never an input argument name, to allow signatures like
+ # call(x, training).
+ keyword_arg_names.add('training')
+ _, unwrapped_call = tf_decorator.unwrap(self.call)
+ bound_args = inspect.getcallargs(
+ unwrapped_call, *call_args, **call_kwargs)
+ if call_arg_spec.varkw is not None:
+ var_kwargs = bound_args.pop(call_arg_spec.varkw)
+ bound_args.update(var_kwargs)
+ keyword_arg_names = keyword_arg_names.union(var_kwargs.keys())
+ all_args = call_arg_spec.args
+ if all_args and bound_args[all_args[0]] is self:
+ # Ignore the 'self' argument of methods
+ bound_args.pop(call_arg_spec.args[0])
+ all_args = all_args[1:]
+ non_input_arg_values = {}
+ input_arg_values = []
+ remaining_args_are_keyword = False
+ for argument_name in all_args:
+ if argument_name in keyword_arg_names:
+ remaining_args_are_keyword = True
+ else:
+ if remaining_args_are_keyword:
+ raise TypeError(
+ 'Found a positional argument in a layer call after a non-input '
+ 'argument. All arguments after "training" must be keyword '
+ 'arguments, and are not tracked as inputs to the layer. '
+ 'Issue occurred with layer "%s"' % (self.name))
+ if remaining_args_are_keyword:
+ non_input_arg_values[argument_name] = bound_args[argument_name]
+ else:
+ input_arg_values.append(bound_args[argument_name])
+ if call_arg_spec.varargs is not None:
+ input_arg_values.extend(bound_args[call_arg_spec.varargs])
+ return input_arg_values, non_input_arg_values
+
+ def _add_inbound_node(self,
+ input_tensors,
+ output_tensors,
+ arguments=None):
+ """Internal method to create an inbound node for the layer.
Arguments:
- weights: a list of Numpy arrays. The number
- of arrays and their shape must match
- number of the dimensions of the weights
- of the layer (i.e. it should match the
- output of `get_weights`).
+ input_tensors: list of input tensors.
+ output_tensors: list of output tensors.
+ arguments: dictionary of keyword arguments that were passed to the
+ `call` method of the layer at the call that created the node.
+ """
+ input_tensors = nest.flatten(input_tensors)
+ output_tensors = nest.flatten(output_tensors)
+
+ # Collect input tensor(s) coordinates.
+ inbound_layers = []
+ node_indices = []
+ tensor_indices = []
+ for x in input_tensors:
+ assert hasattr(x, '_keras_history')
+ inbound_layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access
+ inbound_layers.append(inbound_layer)
+ node_indices.append(node_index)
+ tensor_indices.append(tensor_index)
+
+ # Create node, add it to inbound nodes.
+ Node(
+ self,
+ inbound_layers=inbound_layers,
+ node_indices=node_indices,
+ tensor_indices=tensor_indices,
+ input_tensors=input_tensors,
+ output_tensors=output_tensors,
+ arguments=arguments)
+
+ # Update tensor history metadata.
+ for i in range(len(output_tensors)):
+ # The metadata attribute consists of 1) a layer instance
+ # 2) a node index for the layer, 3) a tensor index for the node.
+ # The allows layer reuse (multiple nodes per layer) and multi-output
+ # or multi-input layers (e.g. a layer can return multiple tensors,
+ # and each can be sent to a different layer).
+ output_tensors[i]._keras_history = (self, len(self._inbound_nodes) - 1, i) # pylint: disable=protected-access
+
+ def _get_node_attribute_at_index(self, node_index, attr, attr_name):
+ """Private utility to retrieves an attribute (e.g. inputs) from a node.
+
+ This is used to implement the methods:
+ - get_input_shape_at
+ - get_output_shape_at
+ - get_input_at
+ etc...
+
+ Arguments:
+ node_index: Integer index of the node from which
+ to retrieve the attribute.
+ attr: Exact node attribute name.
+ attr_name: Human-readable attribute name, for error messages.
+
+ Returns:
+ The layer's attribute `attr` at the node of index `node_index`.
Raises:
- ValueError: If the provided weights list does not match the
- layer's specifications.
+ RuntimeError: If the layer has no inbound nodes, or if called in Eager
+ mode.
+ ValueError: If the index provided does not match any node.
"""
- params = self.weights
- if len(params) != len(weights):
- raise ValueError('You called `set_weights(weights)` on layer "' +
- self.name + '" with a weight list of length ' +
- str(len(weights)) + ', but the layer was expecting ' +
- str(len(params)) + ' weights. Provided weights: ' +
- str(weights)[:50] + '...')
- if not params:
- return
- weight_value_tuples = []
- param_values = backend.batch_get_value(params)
- for pv, p, w in zip(param_values, params, weights):
- if pv.shape != w.shape:
- raise ValueError('Layer weight shape ' + str(pv.shape) +
- ' not compatible with '
- 'provided weight shape ' + str(w.shape))
- weight_value_tuples.append((p, w))
- backend.batch_set_value(weight_value_tuples)
-
- def get_weights(self):
- """Returns the current weights of the layer.
-
- Returns:
- Weights values as a list of numpy arrays.
- """
- params = self.weights
- return backend.batch_get_value(params)
-
- def get_config(self):
- """Returns the config of the layer.
-
- A layer config is a Python dictionary (serializable)
- containing the configuration of a layer.
- The same layer can be reinstantiated later
- (without its trained weights) from this configuration.
-
- The config of a layer does not include connectivity
- information, nor the layer class name. These are handled
- by `Network` (one layer of abstraction above).
-
- Returns:
- Python dictionary.
- """
- config = {'name': self.name, 'trainable': self.trainable}
- if hasattr(self, '_batch_input_shape'):
- config['batch_input_shape'] = self._batch_input_shape
- if hasattr(self, 'dtype'):
- config['dtype'] = self.dtype
- return config
-
- @classmethod
- def from_config(cls, config):
- """Creates a layer from its config.
-
- This method is the reverse of `get_config`,
- capable of instantiating the same layer from the config
- dictionary. It does not handle layer connectivity
- (handled by Network), nor weights (handled by `set_weights`).
-
- Arguments:
- config: A Python dictionary, typically the
- output of get_config.
-
- Returns:
- A layer instance.
- """
- return cls(**config)
+ if not self._inbound_nodes:
+ raise RuntimeError('The layer has never been called '
+ 'and thus has no defined ' + attr_name + '.')
+ if not len(self._inbound_nodes) > node_index:
+ raise ValueError('Asked to get ' + attr_name + ' at node ' +
+ str(node_index) + ', but the layer has only ' +
+ str(len(self._inbound_nodes)) + ' inbound nodes.')
+ values = getattr(self._inbound_nodes[node_index], attr)
+ if len(values) == 1:
+ return values[0]
+ else:
+ return values
@property
def _static_graph_friendly(self):
@@ -1687,55 +1581,6 @@
return self._call_is_graph_friendly
-@tf_export(
- 'keras.layers.InputSpec', v1=['keras.layers.InputSpec', 'layers.InputSpec'])
-class InputSpec(object):
- """Specifies the ndim, dtype and shape of every input to a layer.
-
- Every layer should expose (if appropriate) an `input_spec` attribute:
- a list of instances of InputSpec (one per input tensor).
-
- A None entry in a shape is compatible with any dimension,
- a None shape is compatible with any shape.
-
- Arguments:
- dtype: Expected DataType of the input.
- shape: Shape tuple, expected shape of the input
- (may include None for unchecked axes).
- ndim: Integer, expected rank of the input.
- max_ndim: Integer, maximum rank of the input.
- min_ndim: Integer, minimum rank of the input.
- axes: Dictionary mapping integer axes to
- a specific dimension value.
- """
-
- def __init__(self,
- dtype=None,
- shape=None,
- ndim=None,
- max_ndim=None,
- min_ndim=None,
- axes=None):
- self.dtype = dtype
- self.shape = shape
- if shape is not None:
- self.ndim = len(shape)
- else:
- self.ndim = ndim
- self.max_ndim = max_ndim
- self.min_ndim = min_ndim
- self.axes = axes or {}
-
- def __repr__(self):
- spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
- ('shape=' + str(self.shape)) if self.shape else '',
- ('ndim=' + str(self.ndim)) if self.ndim else '',
- ('max_ndim=' + str(self.max_ndim)) if self.max_ndim else '',
- ('min_ndim=' + str(self.min_ndim)) if self.min_ndim else '',
- ('axes=' + str(self.axes)) if self.axes else '']
- return 'InputSpec(%s)' % ', '.join(x for x in spec if x)
-
-
class Node(object):
"""A `Node` describes the connectivity between two layers.
@@ -1848,192 +1693,12 @@
}
-def unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace='',
- zero_based=False):
- """Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
-
- Arguments:
- name: String name to make unique.
- name_uid_map: An optional defaultdict(int) to use when creating unique
- names. If None (default), uses a per-Graph dictionary.
- avoid_names: An optional set or dict with names which should not be used. If
- None (default) does not avoid any names.
- namespace: Gets a name which is unique within the (graph, namespace). Layers
- which are not Networks use a blank namespace and so get graph-global
- names.
- zero_based: If True, name sequences start with no suffix (e.g. "dense",
- "dense_1"). If False, naming is one-based ("dense_1", "dense_2").
-
- Returns:
- Unique string name.
-
- Example:
-
- ```python
- _unique_layer_name('dense') # dense_1
- _unique_layer_name('dense') # dense_2
- ```
- """
- if name_uid_map is None:
- name_uid_map = get_default_graph_uid_map()
- if avoid_names is None:
- avoid_names = set()
- proposed_name = None
- while proposed_name is None or proposed_name in avoid_names:
- name_key = (namespace, name)
- if zero_based:
- number = name_uid_map[name_key]
- if number:
- proposed_name = name + '_' + str(number)
- else:
- proposed_name = name
- name_uid_map[name_key] += 1
- else:
- name_uid_map[name_key] += 1
- proposed_name = name + '_' + str(name_uid_map[name_key])
- return proposed_name
-
-
-def have_all_keras_metadata(iterable_or_element):
- if not isinstance(iterable_or_element, (list, tuple)):
- iterable = [iterable_or_element]
- else:
- iterable = nest.flatten(iterable_or_element)
- return all([hasattr(x, '_keras_history') for x in iterable])
-
-
-def collect_previous_mask(input_tensors):
- """Retrieves the output mask(s) of the previous node.
-
- Arguments:
- input_tensors: A tensor or list of tensors.
-
- Returns:
- A mask tensor or list of mask tensors.
- """
- input_tensors = nest.flatten(input_tensors)
- masks = []
- for x in input_tensors:
- if hasattr(x, '_keras_mask'):
- mask = x._keras_mask # pylint: disable=protected-access
- masks.append(mask)
- else:
- masks.append(None)
- if len(masks) == 1:
- return masks[0]
- return masks
-
-
-def get_default_graph_uid_map():
- # TODO(fchollet): refactor this into backend.
- graph = ops.get_default_graph()
- name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS.get(graph, None)
- if name_uid_map is None:
- name_uid_map = collections_lib.defaultdict(int)
- backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map
- return name_uid_map
-
-
-def make_variable(name,
- shape=None,
- dtype=dtypes.float32,
- initializer=None,
- partition_info=None,
- trainable=None,
- caching_device=None,
- validate_shape=True,
- constraint=None,
- use_resource=None,
- collections=None,
- synchronization=tf_variables.VariableSynchronization.AUTO,
- aggregation=tf_variables.VariableAggregation.NONE,
- partitioner=None): # pylint: disable=unused-argument
- """Temporary util to create a variable (relies on `variable_scope.variable`).
-
- Some reuse-related technicalities prevent us from using
- `variable_scope.get_variable()` directly, so we use a subcomponent
- that has fewer constraints (`variable_scope.variable()`).
-
- In the longer term, it seems like a similar "default variable creator" method
- should exist in `CheckpointableBase` instead. When this happens, we can get
- rid of this temporary solution.
-
- TODO(fchollet): remove this method when no longer needed.
- TODO(fchollet): handle `partitioner` argument.
-
- Arguments:
- name: Variable name.
- shape: Variable shape.
- dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
- initializer: Initializer instance (callable).
- partition_info: Not handled at this time.
- trainable: Whether the variable should be part of the layer's
- "trainable_variables" (e.g. variables, biases)
- or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
- Note, if the current variable scope is marked as non-trainable
- then this parameter is ignored and any added variables are also
- marked as non-trainable. `trainable` defaults to `True` unless
- `synchronization` is set to `ON_READ`.
- caching_device: Passed to `tf.Variable`.
- validate_shape: Passed to `tf.Variable`.
- constraint: Constraint instance (callable).
- use_resource: Whether to use a `ResourceVariable`.
- collections: List of graph collections keys. The new variable is added to
- these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
- synchronization: Indicates when a distributed a variable will be
- aggregated. Accepted values are constants defined in the class
- `tf.VariableSynchronization`. By default the synchronization is set to
- `AUTO` and the current `DistributionStrategy` chooses
- when to synchronize. If `synchronization` is set to `ON_READ`,
- `trainable` must not be set to `True`.
- aggregation: Indicates how a distributed variable will be aggregated.
- Accepted values are constants defined in the class
- `tf.VariableAggregation`.
- partitioner: Not handled at this time.
-
- Returns:
- Variable instance.
- """
- initializing_from_value = False
- if initializer is not None and not callable(initializer):
- initializing_from_value = True
-
- with ops.init_scope():
- if initializing_from_value:
- init_val = initializer
- variable_dtype = None
- else:
- # Instantiate initializer if provided initializer is a type object.
- if isinstance(initializer, type(init_ops.Initializer)):
- initializer = initializer(dtype=dtype)
- init_val = lambda: initializer( # pylint: disable=g-long-lambda
- shape, dtype=dtype, partition_info=partition_info)
- variable_dtype = dtype.base_dtype
- if use_resource is None:
- use_resource = True
-
- # TODO(apassos,rohanj) figure out how to remove collections from here so we
- # can remove the V1.
- v = tf_variables.VariableV1(
- initial_value=init_val,
- name=name,
- trainable=trainable,
- caching_device=caching_device,
- dtype=variable_dtype,
- validate_shape=validate_shape,
- constraint=constraint,
- use_resource=use_resource,
- collections=collections,
- synchronization=synchronization,
- aggregation=aggregation)
- return v
-
-
def default(method):
"""Decorates a method to detect overrides in subclasses."""
method._is_default = True
return method
-def generate_placeholders_from_shape(shape):
- return array_ops.placeholder(shape=shape, dtype=backend.floatx())
+# Avoid breaking users who directly import this symbol from this file.
+# TODO(fchollet): remove this.
+InputSpec = input_spec.InputSpec # pylint:disable=invalid-name
diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py
new file mode 100644
index 0000000..d2f947f
--- /dev/null
+++ b/tensorflow/python/keras/engine/base_layer_utils.py
@@ -0,0 +1,236 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains private utilities used mainly by the base Layer class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections as collections_lib
+import enum
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.util import nest
+
+
+class CallConvention(enum.Enum):
+ """Calling conventions for passing `Layer` inputs to `Layer.call`."""
+ # The Layer takes inputs as its first argument, named "inputs" for
+ # compatibility with the signature of Layer.__call__. This is the mode assumed
+ # for Layers which are not subclassed Models.
+ EXPLICIT_INPUTS_ARGUMENT = 1
+ # The Layer takes a single positional argument, not named "inputs". It's
+ # treated like an "inputs" argument.
+ SINGLE_POSITIONAL_ARGUMENT = 2
+ # The Layer has multiple positional arguments to which its inputs should be
+ # bound.
+ POSITIONAL_ARGUMENTS_ARE_INPUTS = 3
+
+
+def create_mean_metric(value, name=None):
+ # TODO(psv): Remove this import when b/110718070 is fixed.
+ from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top
+ metric_obj = metrics_module.Mean(name=name)
+ result = metric_obj(value)
+ return metric_obj, result
+
+
+def make_variable(name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ partition_info=None,
+ trainable=None,
+ caching_device=None,
+ validate_shape=True,
+ constraint=None,
+ use_resource=None,
+ collections=None,
+ synchronization=tf_variables.VariableSynchronization.AUTO,
+ aggregation=tf_variables.VariableAggregation.NONE,
+ partitioner=None): # pylint: disable=unused-argument
+ """Temporary util to create a variable (relies on `variable_scope.variable`).
+
+ Some reuse-related technicalities prevent us from using
+ `variable_scope.get_variable()` directly, so we use a subcomponent
+ that has fewer constraints (`variable_scope.variable()`).
+
+ In the longer term, it seems like a similar "default variable creator" method
+ should exist in `CheckpointableBase` instead. When this happens, we can get
+ rid of this temporary solution.
+
+ TODO(fchollet): remove this method when no longer needed.
+ TODO(fchollet): handle `partitioner` argument.
+
+ Arguments:
+ name: Variable name.
+ shape: Variable shape.
+ dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
+ initializer: Initializer instance (callable).
+ partition_info: Not handled at this time.
+ trainable: Whether the variable should be part of the layer's
+ "trainable_variables" (e.g. variables, biases)
+ or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
+ Note, if the current variable scope is marked as non-trainable
+ then this parameter is ignored and any added variables are also
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
+ caching_device: Passed to `tf.Variable`.
+ validate_shape: Passed to `tf.Variable`.
+ constraint: Constraint instance (callable).
+ use_resource: Whether to use a `ResourceVariable`.
+ collections: List of graph collections keys. The new variable is added to
+ these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ `tf.VariableSynchronization`. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ `tf.VariableAggregation`.
+ partitioner: Not handled at this time.
+
+ Returns:
+ Variable instance.
+ """
+ initializing_from_value = False
+ if initializer is not None and not callable(initializer):
+ initializing_from_value = True
+
+ with ops.init_scope():
+ if initializing_from_value:
+ init_val = initializer
+ variable_dtype = None
+ else:
+ # Instantiate initializer if provided initializer is a type object.
+ if isinstance(initializer, type(init_ops.Initializer)):
+ initializer = initializer(dtype=dtype)
+ init_val = lambda: initializer( # pylint: disable=g-long-lambda
+ shape, dtype=dtype, partition_info=partition_info)
+ variable_dtype = dtype.base_dtype
+ if use_resource is None:
+ use_resource = True
+
+ # TODO(apassos,rohanj) figure out how to remove collections from here so we
+ # can remove the V1.
+ v = tf_variables.VariableV1(
+ initial_value=init_val,
+ name=name,
+ trainable=trainable,
+ caching_device=caching_device,
+ dtype=variable_dtype,
+ validate_shape=validate_shape,
+ constraint=constraint,
+ use_resource=use_resource,
+ collections=collections,
+ synchronization=synchronization,
+ aggregation=aggregation)
+ return v
+
+
+def get_default_graph_uid_map():
+ # TODO(fchollet): refactor this into backend.
+ graph = ops.get_default_graph()
+ name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS.get(graph, None)
+ if name_uid_map is None:
+ name_uid_map = collections_lib.defaultdict(int)
+ backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map
+ return name_uid_map
+
+
+def unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace='',
+ zero_based=False):
+ """Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
+
+ Arguments:
+ name: String name to make unique.
+ name_uid_map: An optional defaultdict(int) to use when creating unique
+ names. If None (default), uses a per-Graph dictionary.
+ avoid_names: An optional set or dict with names which should not be used. If
+ None (default) does not avoid any names.
+ namespace: Gets a name which is unique within the (graph, namespace). Layers
+ which are not Networks use a blank namespace and so get graph-global
+ names.
+ zero_based: If True, name sequences start with no suffix (e.g. "dense",
+ "dense_1"). If False, naming is one-based ("dense_1", "dense_2").
+
+ Returns:
+ Unique string name.
+
+ Example:
+
+ ```python
+ _unique_layer_name('dense') # dense_1
+ _unique_layer_name('dense') # dense_2
+ ```
+ """
+ if name_uid_map is None:
+ name_uid_map = get_default_graph_uid_map()
+ if avoid_names is None:
+ avoid_names = set()
+ proposed_name = None
+ while proposed_name is None or proposed_name in avoid_names:
+ name_key = (namespace, name)
+ if zero_based:
+ number = name_uid_map[name_key]
+ if number:
+ proposed_name = name + '_' + str(number)
+ else:
+ proposed_name = name
+ name_uid_map[name_key] += 1
+ else:
+ name_uid_map[name_key] += 1
+ proposed_name = name + '_' + str(name_uid_map[name_key])
+ return proposed_name
+
+
+def collect_previous_mask(input_tensors):
+ """Retrieves the output mask(s) of the previous node.
+
+ Arguments:
+ input_tensors: A tensor or list of tensors.
+
+ Returns:
+ A mask tensor or list of mask tensors.
+ """
+ input_tensors = nest.flatten(input_tensors)
+ masks = []
+ for x in input_tensors:
+ if hasattr(x, '_keras_mask'):
+ mask = x._keras_mask # pylint: disable=protected-access
+ masks.append(mask)
+ else:
+ masks.append(None)
+ if len(masks) == 1:
+ return masks[0]
+ return masks
+
+
+def have_all_keras_metadata(iterable_or_element):
+ if not isinstance(iterable_or_element, (list, tuple)):
+ iterable = [iterable_or_element]
+ else:
+ iterable = nest.flatten(iterable_or_element)
+ return all(hasattr(x, '_keras_history') for x in iterable)
+
+
+def generate_placeholders_from_shape(shape):
+ return array_ops.placeholder(shape=shape, dtype=backend.floatx())
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index 41da393..8b00761 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -23,15 +23,16 @@
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
+from tensorflow.python.distribute import distribute_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest
@@ -145,11 +146,14 @@
for e in distribution_strategy.unwrap(flattened)]
-def validate_callbacks(input_callbacks):
+def validate_callbacks(input_callbacks, optimizer, current_strategy):
"""Validate whether given callbacks are supported by DistributionStrategy.
Args:
input_callbacks: List of callbacks passed by the user to fit.
+ optimizer: Optimizer instance used to train the model.
+ current_strategy: The DistributionStrategy used to distribute training
+ and validation.
Raises:
ValueError: If `LearningRateScheduler` or `ReduceLROnPlateau` is one of the
@@ -171,12 +175,18 @@
'these attributes are not set. You can access each of '
'the individual distributed models using the '
'`_grouped_model` attribute of your original model.')
- if isinstance(callback, callbacks.LearningRateScheduler):
- raise ValueError('LearningRateScheduler callback is not supported with '
- 'DistributionStrategy.')
- if isinstance(callback, callbacks.ReduceLROnPlateau):
- raise ValueError('ReduceLROnPlateau callback is not supported with '
- 'DistributionStrategy.')
+ if isinstance(callback, (callbacks.LearningRateScheduler,
+ callbacks.ReduceLROnPlateau)):
+ strategy_name = current_strategy.__class__.__name__
+ # TODO(anjalisridhar): We might need to add a condition for multi
+ # worker strategy when we support it in Keras.
+ if is_tpu_strategy(current_strategy):
+ raise ValueError('%s callback is not supported with %s.' %
+ (callback, strategy_name))
+
+ if not isinstance(optimizer, optimizer_v2.OptimizerV2):
+ raise ValueError('You must specify a Keras Optimizer V2 when using '
+ '%s callback with DistributionStrategy.' % callback)
# If users want to use the TensorBoard callback they cannot use certain
# features of the callback that involve accessing model attributes and
@@ -350,6 +360,7 @@
session = session_module.Session(
config=dc_session_config, target=worker_context.master_target)
else:
+ distribution_strategy.configure(session_config)
session = session_module.Session(config=session_config)
K.set_session(session)
@@ -381,9 +392,13 @@
if is_tpu_strategy(distribution_strategy):
for i in [x, y]:
- if isinstance(i, dataset_ops.Dataset):
+ if isinstance(i, dataset_ops.DatasetV2):
shapes = nest.flatten(i.output_shapes)
- if any([not s.is_fully_defined() for s in shapes]):
+ try:
+ s = next(s for s in shapes if not s.is_fully_defined())
+ except StopIteration:
+ continue
+ else:
raise ValueError(
'Using TPUs currently requires fully defined shapes. Either use '
'set_shape() on the input tensors or use '
@@ -391,14 +406,11 @@
'Found unknown shape {} in input {}.'.format(s, i))
-# TODO(b/118776054): Currently we support global batch size for TPUStrategy
-# and CoreMirroredStrategy only. Remove this check when contrib MirroredStrategy
-# is no longer needed.
+# TODO(b/118776054): Currently we support global batch size for TPUStrategy and
+# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is
+# no longer needed.
def global_batch_size_supported(distribution_strategy):
- strategy_name = distribution_strategy.__class__.__name__
- # TODO(priyag): Change this to whatever condition makes sense when
- # CoreMirroredStrategy is moved to core and renamed.
- return strategy_name in ('TPUStrategy', 'CoreMirroredStrategy')
+ return distribution_strategy.extended._global_batch_size # pylint: disable=protected-access
# TODO(sourabhbajaj): Remove this once we use the same API for all strategies.
diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py
index 590b935..9874efe 100644
--- a/tensorflow/python/keras/engine/input_layer.py
+++ b/tensorflow/python/keras/engine/input_layer.py
@@ -19,12 +19,10 @@
from __future__ import division
from __future__ import print_function
-from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.utils import tf_utils
-from tensorflow.python.ops import array_ops
from tensorflow.python.util.tf_export import tf_export
@@ -94,19 +92,19 @@
else:
batch_input_shape = None
graph = backend.get_graph()
- with context.graph_mode():
- with graph.as_default():
- # In graph mode, create a graph placeholder to call the layer on.
- if sparse:
- input_tensor = array_ops.sparse_placeholder(
- shape=batch_input_shape,
- dtype=dtype,
- name=self.name)
- else:
- input_tensor = array_ops.placeholder(
- shape=batch_input_shape,
- dtype=dtype,
- name=self.name)
+ with graph.as_default():
+ # In graph mode, create a graph placeholder to call the layer on.
+ if sparse:
+ input_tensor = backend.placeholder(
+ shape=batch_input_shape,
+ dtype=dtype,
+ name=self.name,
+ sparse=True)
+ else:
+ input_tensor = backend.placeholder(
+ shape=batch_input_shape,
+ dtype=dtype,
+ name=self.name)
self.is_placeholder = True
self._batch_input_shape = batch_input_shape
diff --git a/tensorflow/python/keras/engine/input_spec.py b/tensorflow/python/keras/engine/input_spec.py
new file mode 100644
index 0000000..7277c16
--- /dev/null
+++ b/tensorflow/python/keras/engine/input_spec.py
@@ -0,0 +1,170 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Contains the InputSpec class."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import zip # pylint: disable=redefined-builtin
+
+from tensorflow.python.util import nest
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export('keras.layers.InputSpec',
+ v1=['keras.layers.InputSpec', 'layers.InputSpec'])
+class InputSpec(object):
+ """Specifies the ndim, dtype and shape of every input to a layer.
+
+ Every layer should expose (if appropriate) an `input_spec` attribute:
+ a list of instances of InputSpec (one per input tensor).
+
+ A None entry in a shape is compatible with any dimension,
+ a None shape is compatible with any shape.
+
+ Arguments:
+ dtype: Expected DataType of the input.
+ shape: Shape tuple, expected shape of the input
+ (may include None for unchecked axes).
+ ndim: Integer, expected rank of the input.
+ max_ndim: Integer, maximum rank of the input.
+ min_ndim: Integer, minimum rank of the input.
+ axes: Dictionary mapping integer axes to
+ a specific dimension value.
+ """
+
+ def __init__(self,
+ dtype=None,
+ shape=None,
+ ndim=None,
+ max_ndim=None,
+ min_ndim=None,
+ axes=None):
+ self.dtype = dtype
+ self.shape = shape
+ if shape is not None:
+ self.ndim = len(shape)
+ else:
+ self.ndim = ndim
+ self.max_ndim = max_ndim
+ self.min_ndim = min_ndim
+ self.axes = axes or {}
+
+ def __repr__(self):
+ spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
+ ('shape=' + str(self.shape)) if self.shape else '',
+ ('ndim=' + str(self.ndim)) if self.ndim else '',
+ ('max_ndim=' + str(self.max_ndim)) if self.max_ndim else '',
+ ('min_ndim=' + str(self.min_ndim)) if self.min_ndim else '',
+ ('axes=' + str(self.axes)) if self.axes else '']
+ return 'InputSpec(%s)' % ', '.join(x for x in spec if x)
+
+
+def assert_input_compatibility(input_spec, inputs, layer_name):
+ """Checks compatibility between the layer and provided inputs.
+
+ This checks that the tensor(s) `inputs` verify the input assumptions
+ of a layer (if any). If not, a clear and actional exception gets raised.
+
+ Arguments:
+ input_spec: An InputSpec instance, or None.
+ inputs: Input tensor or list of input tensors.
+ layer_name: String, name of the layer (for error message formatting).
+
+ Raises:
+ ValueError: in case of mismatch between
+ the provided inputs and the expectations of the layer.
+ """
+ if not input_spec:
+ return
+ if not isinstance(input_spec, (list, tuple)):
+ input_spec = nest.flatten(input_spec)
+
+ inputs = nest.flatten(inputs)
+ if len(inputs) != len(input_spec):
+ raise ValueError('Layer ' + layer_name + ' expects ' +
+ str(len(input_spec)) + ' inputs, '
+ 'but it received ' + str(len(inputs)) +
+ ' input tensors. Inputs received: ' + str(inputs))
+ for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
+ if spec is None:
+ continue
+
+ if (spec.ndim is not None or
+ spec.min_ndim is not None or
+ spec.max_ndim is not None):
+ if x.shape.ndims is None:
+ raise ValueError('Input ' + str(input_index) + ' of layer ' +
+ layer_name + ' is incompatible with the layer: '
+ 'its rank is undefined, but the layer requires a '
+ 'defined rank.')
+
+ # Check ndim.
+ if spec.ndim is not None:
+ ndim = x.shape.ndims
+ if ndim != spec.ndim:
+ raise ValueError('Input ' + str(input_index) + ' of layer ' +
+ layer_name + ' is incompatible with the layer: '
+ 'expected ndim=' + str(spec.ndim) + ', found ndim=' +
+ str(ndim) + '. Full shape received: ' +
+ str(x.shape.as_list()))
+ if spec.max_ndim is not None:
+ ndim = x.shape.ndims
+ if ndim is not None and ndim > spec.max_ndim:
+ raise ValueError('Input ' + str(input_index) + ' of layer ' +
+ layer_name + ' is incompatible with the layer: '
+ 'expected max_ndim=' + str(spec.max_ndim) +
+ ', found ndim=' + str(ndim))
+ if spec.min_ndim is not None:
+ ndim = x.shape.ndims
+ if ndim is not None and ndim < spec.min_ndim:
+ raise ValueError('Input ' + str(input_index) + ' of layer ' +
+ layer_name + ' is incompatible with the layer: '
+ ': expected min_ndim=' + str(spec.min_ndim) +
+ ', found ndim=' + str(ndim) +
+ '. Full shape received: ' +
+ str(x.shape.as_list()))
+ # Check dtype.
+ if spec.dtype is not None:
+ if x.dtype != spec.dtype:
+ raise ValueError('Input ' + str(input_index) + ' of layer ' +
+ layer_name + ' is incompatible with the layer: '
+ 'expected dtype=' + str(spec.dtype) +
+ ', found dtype=' + str(x.dtype))
+ # Check specific shape axes.
+ if spec.axes:
+ shape = x.shape.as_list()
+ if shape is not None:
+ for axis, value in spec.axes.items():
+ if hasattr(value, 'value'):
+ value = value.value
+ if value is not None and shape[int(axis)] not in {value, None}:
+ raise ValueError(
+ 'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
+ ' incompatible with the layer: expected axis ' + str(axis) +
+ ' of input shape to have value ' + str(value) +
+ ' but received input with shape ' + str(shape))
+ # Check shape.
+ if spec.shape is not None:
+ shape = x.shape.as_list()
+ if shape is not None:
+ for spec_dim, dim in zip(spec.shape, shape):
+ if spec_dim is not None and dim is not None:
+ if spec_dim != dim:
+ raise ValueError('Input ' + str(input_index) +
+ ' is incompatible with layer ' + layer_name +
+ ': expected shape=' + str(spec.shape) +
+ ', found shape=' + str(shape))
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 4163176..1040fd8 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -36,6 +36,7 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import saving
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
@@ -162,7 +163,8 @@
@checkpointable.no_automatic_dependency_tracking
def _init_graph_network(self, inputs, outputs, name=None):
- self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
+ self._call_convention = (base_layer_utils
+ .CallConvention.EXPLICIT_INPUTS_ARGUMENT)
# Normalize and set self.inputs, self.outputs.
if isinstance(inputs, (list, tuple)):
self.inputs = list(inputs) # Tensor or list of tensors.
@@ -305,7 +307,7 @@
return self._call_is_graph_friendly
def _determine_call_convention(self, call_argspec):
- """Decides how `self.call()` is invoked. See base_layer.CallConvention."""
+ """Decides how `self.call()` is invoked. See `CallConvention`."""
if call_argspec.varargs:
may_take_single_argument = False
else:
@@ -337,11 +339,11 @@
"Model.call() takes a single positional argument (to which "
"inputs are passed by convention) and a separate 'inputs' "
"argument. Unable to determine which arguments are inputs.")
- return base_layer.CallConvention.SINGLE_POSITIONAL_ARGUMENT
+ return base_layer_utils.CallConvention.SINGLE_POSITIONAL_ARGUMENT
if 'inputs' in call_argspec.args:
- return base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
+ return base_layer_utils.CallConvention.EXPLICIT_INPUTS_ARGUMENT
else:
- return base_layer.CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS
+ return base_layer_utils.CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS
def _track_layers(self, layers):
"""Add Checkpointable dependencies on a list of Layers."""
@@ -428,8 +430,8 @@
@property
def stateful(self):
- return any([(hasattr(layer, 'stateful') and layer.stateful)
- for layer in self.layers])
+ return any((hasattr(layer, 'stateful') and layer.stateful)
+ for layer in self.layers)
def reset_states(self):
for layer in self.layers:
@@ -807,10 +809,10 @@
graph = func_graph.FuncGraph('graph')
with graph.as_default():
if isinstance(input_shape, list):
- x = [base_layer.generate_placeholders_from_shape(shape)
+ x = [base_layer_utils.generate_placeholders_from_shape(shape)
for shape in input_shape]
else:
- x = base_layer.generate_placeholders_from_shape(input_shape)
+ x = base_layer_utils.generate_placeholders_from_shape(input_shape)
kwargs = {}
num_call_args = len(tf_inspect.getfullargspec(self.call).args)
diff --git a/tensorflow/python/keras/engine/saving.py b/tensorflow/python/keras/engine/saving.py
index 22c48e3..54d9e32 100644
--- a/tensorflow/python/keras/engine/saving.py
+++ b/tensorflow/python/keras/engine/saving.py
@@ -917,7 +917,7 @@
chunked_data = np.array_split(data_npy, num_chunks)
# This will never loop forever thanks to the test above.
- while any([x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data]):
+ while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data):
num_chunks += 1
chunked_data = np.array_split(data_npy, num_chunks)
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index f376f08..375d101 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -32,6 +32,7 @@
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import saving
from tensorflow.python.keras.engine import training
+from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -992,5 +993,57 @@
AssertionError, 'Nothing except the root object matched'):
m.load_weights(save_path)
+ @test_util.run_in_graph_and_eager_modes
+ def test_directory_passed(self):
+ m = keras.Model()
+ v = m.add_weight(name='v', shape=[])
+ self.evaluate(v.assign(42.))
+ prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'ckpt/')
+ m.save_weights(prefix)
+ self.evaluate(v.assign(2.))
+ m.load_weights(prefix)
+ self.assertEqual(42., self.evaluate(v))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_relative_path(self):
+ m = keras.Model()
+ v = m.add_weight(name='v', shape=[])
+ os.chdir(self.get_temp_dir())
+
+ prefix = 'ackpt'
+ self.evaluate(v.assign(42.))
+ m.save_weights(prefix)
+ self.assertTrue(file_io.file_exists('ackpt.index'))
+ self.evaluate(v.assign(1.))
+ m.load_weights(prefix)
+ self.assertEqual(42., self.evaluate(v))
+
+ prefix = 'subdir/ackpt'
+ self.evaluate(v.assign(43.))
+ m.save_weights(prefix)
+ self.assertTrue(file_io.file_exists('subdir/ackpt.index'))
+ self.evaluate(v.assign(2.))
+ m.load_weights(prefix)
+ self.assertEqual(43., self.evaluate(v))
+
+ prefix = 'ackpt/'
+ self.evaluate(v.assign(44.))
+ m.save_weights(prefix)
+ self.assertTrue(file_io.file_exists('ackpt/.index'))
+ self.evaluate(v.assign(3.))
+ m.load_weights(prefix)
+ self.assertEqual(44., self.evaluate(v))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_nonexistant_prefix_directory(self):
+ m = keras.Model()
+ v = m.add_weight(name='v', shape=[])
+ self.evaluate(v.assign(42.))
+ prefix = os.path.join(self.get_temp_dir(), '{}'.format(ops.uid()), 'bckpt')
+ m.save_weights(prefix)
+ self.evaluate(v.assign(2.))
+ m.load_weights(prefix)
+ self.assertEqual(42., self.evaluate(v))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index d926b53..4d3fffb 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -41,6 +41,8 @@
from tensorflow.python.keras.engine.network import Network
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils.generic_utils import slice_arrays
+from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -568,16 +570,16 @@
'" missing from loss dictionary. We assume '
'this was done on purpose. The fit and evaluate APIs will not be '
'expecting any data to be passed to "' + name + '".')
- loss_functions.append(losses.get(loss.get(name)))
+ loss_functions.append(training_utils.get_loss_function(loss.get(name)))
elif isinstance(loss, list):
if len(loss) != len(self.outputs):
raise ValueError('When passing a list as loss, '
'it should have one entry per model outputs. '
'The model has ' + str(len(self.outputs)) +
' outputs, but you passed loss=' + str(loss))
- loss_functions = [losses.get(l) for l in loss]
+ loss_functions = [training_utils.get_loss_function(l) for l in loss]
else:
- loss_function = losses.get(loss)
+ loss_function = training_utils.get_loss_function(loss)
loss_functions = [loss_function for _ in range(len(self.outputs))]
self.loss_functions = loss_functions
@@ -693,11 +695,15 @@
target = None
if target is None or K.is_placeholder(target):
if target is None:
+ target_dtype = losses.LABEL_DTYPES_FOR_LOSSES.get(
+ self.loss_functions[i],
+ K.dtype(self.outputs[i]))
+
target = K.placeholder(
ndim=len(shape),
name=name + '_target',
sparse=K.is_sparse(self.outputs[i]),
- dtype=K.dtype(self.outputs[i]))
+ dtype=target_dtype)
self._feed_targets.append(target)
self._feed_outputs.append(self.outputs[i])
self._feed_output_names.append(name)
@@ -726,8 +732,21 @@
mask = masks[i]
loss_weight = loss_weights_list[i]
with K.name_scope(self.output_names[i] + '_loss'):
- weighted_loss = training_utils.weighted_masked_objective(loss_fn)
- output_loss = weighted_loss(y_true, y_pred, sample_weight, mask)
+ if isinstance(loss_fn, losses.Loss):
+ if mask is not None:
+ mask = math_ops.cast(mask, y_pred.dtype)
+ # Update weights with mask.
+ if sample_weight is None:
+ sample_weight = mask
+ else:
+ # Update dimensions of weights to match with mask if possible.
+ mask, _, sample_weight = squeeze_or_expand_dimensions(
+ mask, None, sample_weight)
+ sample_weight *= mask
+ output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
+ else:
+ weighted_loss = training_utils.weighted_masked_objective(loss_fn)
+ output_loss = weighted_loss(y_true, y_pred, sample_weight, mask)
if len(self.outputs) > 1:
# Keep track of the un-aggregated loss result tensor.
@@ -735,8 +754,10 @@
'_loss'] = output_loss
# Keep track of stateful result tensor and function for the loss.
+ loss_name = loss_fn.name if isinstance(
+ loss_fn, losses.Loss) else loss_fn.__name__
mean_wrapped_loss = metrics_module.MeanMetricWrapper(
- loss_fn, name=loss_fn.__name__)
+ loss_fn, name=loss_name)
result_tensor = training_utils.call_metric_function(
mean_wrapped_loss,
y_true,
@@ -996,7 +1017,7 @@
# TODO(anjalisridhar): Remove this check once we refactor the
# _standardize_user_data code path. This check is already present elsewhere
# in the codebase.
- if check_steps and isinstance(x, dataset_ops.Dataset) and steps is None:
+ if check_steps and isinstance(x, dataset_ops.DatasetV2) and steps is None:
raise ValueError('When using Datasets as input, '
'you should specify the `{steps_name}` argument.'
.format(steps_name=steps_name))
@@ -1039,7 +1060,7 @@
x = dataset_ops.Dataset.from_tensor_slices(var_x)
x = x.batch(batch_size, drop_remainder=drop_remainder)
- assert isinstance(x, dataset_ops.Dataset)
+ assert isinstance(x, dataset_ops.DatasetV2)
with self._distribution_strategy.scope():
iterator = self._distribution_strategy.make_dataset_iterator(x)
@@ -1128,7 +1149,7 @@
shuffle=shuffle)
return iterator, None, None
- if isinstance(x, dataset_ops.Dataset):
+ if isinstance(x, dataset_ops.DatasetV2):
if context.executing_eagerly():
x = x.make_one_shot_iterator()
else:
@@ -1227,7 +1248,10 @@
# to match the value shapes.
if not self.inputs:
is_build_called = True
- self._set_inputs(x)
+ cast_inputs = x
+ if training_utils.has_tensors(x):
+ cast_inputs = training_utils.cast_if_floating_dtype(x)
+ self._set_inputs(cast_inputs)
else:
dict_inputs = isinstance(self.inputs, dict)
if dict_inputs and context.executing_eagerly():
@@ -1243,6 +1267,8 @@
if not self._is_compiled:
# On-the-fly compilation of the model.
# We need to use `y` to set the model targets.
+ if training_utils.has_tensors(y):
+ y = training_utils.cast_if_floating_dtype(y)
if isinstance(y, (list, tuple)):
if not all(isinstance(v, np.ndarray) or
tensor_util.is_tensor(v) for v in y):
@@ -1650,7 +1676,8 @@
# Validate and standardize user data.
if self._distribution_strategy:
- distributed_training_utils.validate_callbacks(callbacks)
+ distributed_training_utils.validate_callbacks(callbacks, self.optimizer,
+ self._distribution_strategy)
distributed_training_utils.validate_inputs(
x, y, self._distribution_strategy)
@@ -1682,7 +1709,7 @@
if validation_data:
if (isinstance(validation_data, iterator_ops.Iterator) or
isinstance(validation_data, iterator_ops.EagerIterator) or
- isinstance(validation_data, dataset_ops.Dataset)):
+ isinstance(validation_data, dataset_ops.DatasetV2)):
val_x = validation_data
val_y = None
val_sample_weight = None
@@ -2186,7 +2213,7 @@
inputs, _, _ = self._standardize_user_data(x)
if self.run_eagerly:
if (isinstance(inputs, iterator_ops.EagerIterator) or
- (isinstance(inputs, dataset_ops.Dataset))):
+ (isinstance(inputs, dataset_ops.DatasetV2))):
inputs = training_utils.cast_if_floating_dtype(inputs)
elif isinstance(inputs, collections.Sequence):
inputs = [
@@ -2462,9 +2489,7 @@
def __init__(self, model):
super(DistributedCallbackModel, self).__init__()
- # TODO(anjalisridhar): Right now the only attributes set are the layer and
- # weights. We may need to set additional attributes as needed since we have
- # not called compile on this model.
+ self.optimizer = model.optimizer
def set_original_model(self, orig_model):
self._original_model = orig_model
diff --git a/tensorflow/python/keras/engine/training_dataset_test.py b/tensorflow/python/keras/engine/training_dataset_test.py
index e8b884e..e79e584 100644
--- a/tensorflow/python/keras/engine/training_dataset_test.py
+++ b/tensorflow/python/keras/engine/training_dataset_test.py
@@ -20,6 +20,8 @@
import logging
+from absl.testing import parameterized
+
import numpy as np
from tensorflow.python import keras
@@ -28,16 +30,24 @@
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import testing_utils
+from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.rmsprop import RMSPropOptimizer
-class TestTrainingWithDatasetIterators(test.TestCase):
+class TestTrainingWithDatasetIterators(test.TestCase, parameterized.TestCase):
+ @parameterized.parameters(
+ {'model': 'functional'},
+ {'model': 'subclass'},
+ )
@tf_test_util.run_in_graph_and_eager_modes
- def test_training_and_eval_methods_on_iterators_single_io(self):
- model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ def test_training_and_eval_methods_on_iterators_single_io(self, model):
+ if model == 'functional':
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ elif model == 'subclass':
+ model = testing_utils.get_small_sequential_mlp(1, 4)
optimizer = RMSPropOptimizer(learning_rate=0.001)
loss = 'mse'
metrics = ['mae', metrics_module.CategoricalAccuracy()]
@@ -137,7 +147,7 @@
'dataset iterator ran out of data')
-class TestTrainingWithDataset(test.TestCase):
+class TestTrainingWithDataset(test.TestCase, parameterized.TestCase):
@tf_test_util.run_in_graph_and_eager_modes
def test_calling_model_on_same_dataset(self):
@@ -240,20 +250,29 @@
model.evaluate(dataset, steps=2, verbose=1)
model.predict(dataset, steps=2)
+ @parameterized.parameters(
+ {'model': 'functional'},
+ {'model': 'subclass'},
+ )
@tf_test_util.run_in_graph_and_eager_modes
- def test_dataset_with_sparse_labels(self):
- model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
- optimizer = RMSPropOptimizer(learning_rate=0.001)
- loss = 'sparse_categorical_crossentropy'
- model.compile(optimizer, loss)
+ def test_dataset_with_sparse_labels(self, model):
+ if model == 'functional':
+ model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
+ elif model == 'subclass':
+ model = testing_utils.get_small_sequential_mlp(1, 4)
- inputs = np.zeros((10, 3))
- targets = np.random.randint(0, 4, size=10, dtype=np.int32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ for loss in ['sparse_categorical_crossentropy',
+ losses_impl.sparse_softmax_cross_entropy]:
+ optimizer = RMSPropOptimizer(learning_rate=0.001)
+ model.compile(optimizer, loss)
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.random.randint(0, 4, size=10, dtype=np.int32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = dataset.batch(10)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
def test_dataset_input_shape_validation(self):
with self.cached_session():
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 878451d..d168323 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -36,7 +36,7 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.distribute import distribute_lib
from tensorflow.python.util import nest
@@ -109,7 +109,7 @@
mode=_Mode.TRAIN)
(grouped_inputs, grouped_outputs, grouped_updates,
- grouped_session_args) = current_strategy.call_for_each_replica(
+ grouped_session_args) = current_strategy.extended.call_for_each_replica(
_per_device_fit_function, args=(model._grouped_model_train,))
(all_inputs, all_outputs, all_updates,
all_session_args) = distributed_training_utils.unwrap_values(
@@ -152,7 +152,7 @@
name='steps_per_run')
with current_strategy.scope():
- ctx = current_strategy.run_steps_on_dataset(
+ ctx = current_strategy.extended.experimental_run_steps_on_iterator(
step_fn, iterator, iterations=steps_per_run,
initial_loop_values=initial_loop_values)
@@ -300,7 +300,7 @@
mode=_Mode.TEST)
(grouped_inputs, grouped_outputs, grouped_updates,
- grouped_session_args) = current_strategy.call_for_each_replica(
+ grouped_session_args) = current_strategy.extended.call_for_each_replica(
_per_device_eval_function, args=(model._grouped_model_test,))
(all_inputs, all_outputs, all_updates,
@@ -335,7 +335,7 @@
with current_strategy.scope():
# TODO(priyag): Use steps_per_run when we use new metrics as they will
# allow handling metric computation at each step using variables.
- ctx = current_strategy.run_steps_on_dataset(
+ ctx = current_strategy.extended.experimental_run_steps_on_iterator(
step_fn, iterator, iterations=1,
initial_loop_values=initial_loop_values)
@@ -414,7 +414,7 @@
mode=_Mode.PREDICT)
(grouped_inputs, grouped_outputs, grouped_updates,
- grouped_session_args) = current_strategy.call_for_each_replica(
+ grouped_session_args) = current_strategy.extended.call_for_each_replica(
_per_device_predict_function, args=(model._grouped_model_predict,))
(all_inputs, all_outputs, all_updates,
@@ -445,7 +445,7 @@
with current_strategy.scope():
# TODO(priyag, sourabhbajaj): Support steps_per_run if/when we add outfeed.
- ctx = current_strategy.run_steps_on_dataset(
+ ctx = current_strategy.extended.experimental_run_steps_on_iterator(
step_fn, iterator, iterations=1,
initial_loop_values=initial_loop_values)
@@ -528,7 +528,7 @@
inputs=None, targets=None, mode=None):
"""Create a cloned model on each replica."""
with strategy.scope():
- grouped_model = strategy.call_for_each_replica(
+ grouped_model = strategy.extended.call_for_each_replica(
_clone_and_build_model, args=(model, inputs, targets))
if mode is _Mode.TRAIN:
model._grouped_model_train = grouped_model
@@ -583,7 +583,7 @@
# Create train ops on each of the devices when we call
# `_per_device_fit_function`.
(grouped_inputs, grouped_outputs, grouped_updates,
- grouped_session_args) = strategy.call_for_each_replica(
+ grouped_session_args) = strategy.extended.call_for_each_replica(
_per_device_function, args=(model._grouped_model,))
if mode == 'train':
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index b2dace84..cd85c36 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -31,9 +31,11 @@
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import callbacks as cbks
+from tensorflow.python.keras import losses as losses_module
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
@@ -128,11 +130,24 @@
else:
weights = None
mask = masks[i]
-
- weighted_masked_fn = training_utils.weighted_masked_objective(loss_fn)
with backend.name_scope(model.output_names[i] + '_loss'):
- output_loss = weighted_masked_fn(
- targets[i], outs[i], weights, mask=mask)
+ if isinstance(loss_fn, losses_module.Loss):
+ if mask is not None:
+ mask = math_ops.cast(mask, outs[i].dtype)
+ # Update weights with mask.
+ if weights is None:
+ weights = mask
+ else:
+ # Update dimensions of weights to match with mask if possible.
+ mask, _, weights = squeeze_or_expand_dimensions(
+ mask, None, weights)
+ weights *= mask
+ output_loss = loss_fn(targets[i], outs[i], sample_weight=weights)
+ else:
+ weighted_masked_fn = training_utils.weighted_masked_objective(loss_fn)
+ output_loss = weighted_masked_fn(
+ targets[i], outs[i], weights, mask=mask)
+
# If the number of outputs is 1 then we don't append the loss metric
# associated with each model output. When there are multiple outputs
# associated with a model, each output's loss is calculated and returned
@@ -351,8 +366,10 @@
output_loss_metrics = []
for i in range(len(model.outputs)):
loss_fn = model.loss_functions[i]
+ loss_name = loss_fn.name if isinstance(
+ loss_fn, losses_module.Loss) else loss_fn.__name__
mean_wrapped_loss = metrics_module.MeanMetricWrapper(
- loss_fn, name=loss_fn.__name__)
+ loss_fn, name=loss_name)
output_loss_metrics.append(mean_wrapped_loss)
num_samples = 0
@@ -744,8 +761,10 @@
output_loss_metrics = []
for i in range(len(model.outputs)):
loss_fn = model.loss_functions[i]
+ loss_name = loss_fn.name if isinstance(
+ loss_fn, losses_module.Loss) else loss_fn.__name__
mean_wrapped_loss = metrics_module.MeanMetricWrapper(
- loss_fn, name=loss_fn.__name__)
+ loss_fn, name=loss_name)
output_loss_metrics.append(mean_wrapped_loss)
callbacks.on_train_begin()
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py
index 45247a2..e7310a7 100644
--- a/tensorflow/python/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/engine/training_generator.py
@@ -67,7 +67,7 @@
else:
raise ValueError('Please specify the `steps_per_epoch` argument.')
- if (isinstance(validation_data, dataset_ops.Dataset) and
+ if (isinstance(validation_data, dataset_ops.DatasetV2) and
context.executing_eagerly()):
validation_data = validation_data.make_one_shot_iterator()
val_gen = (data_utils.is_generator_or_sequence(validation_data) or
@@ -388,7 +388,9 @@
if isinstance(generator_output, tuple):
# Compatibility with the generators
# used for training.
- if len(generator_output) == 2:
+ if len(generator_output) == 1:
+ x = generator_output[0]
+ elif len(generator_output) == 2:
x, _ = generator_output
elif len(generator_output) == 3:
x, _, _ = generator_output
diff --git a/tensorflow/python/keras/engine/training_generator_test.py b/tensorflow/python/keras/engine/training_generator_test.py
index 88e8943..42cfa3b 100644
--- a/tensorflow/python/keras/engine/training_generator_test.py
+++ b/tensorflow/python/keras/engine/training_generator_test.py
@@ -21,220 +21,269 @@
import os
import unittest
+from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import metrics as metrics_module
+from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
from tensorflow.python.training.rmsprop import RMSPropOptimizer
-class TestGeneratorMethods(test.TestCase):
+def custom_generator(mode=2):
+ batch_size = 10
+ num_samples = 50
+ arr_data = np.random.random((num_samples, 2))
+ arr_labels = np.random.random((num_samples, 4))
+ arr_weights = np.random.random((num_samples,))
+ i = 0
+ while True:
+ batch_index = i * batch_size % num_samples
+ i += 1
+ start = batch_index
+ end = start + batch_size
+ x = arr_data[start: end]
+ y = arr_labels[start: end]
+ w = arr_weights[start: end]
+ if mode == 1:
+ yield x
+ elif mode == 2:
+ yield x, y
+ else:
+ yield x, y, w
+
+
+@tf_test_util.run_all_in_graph_and_eager_modes
+class TestGeneratorMethods(test.TestCase, parameterized.TestCase):
@unittest.skipIf(
os.name == 'nt',
'use_multiprocessing=True does not work on windows properly.')
- def test_generator_methods(self):
- arr_data = np.random.random((50, 2))
- arr_labels = np.random.random((50,))
+ @parameterized.parameters('sequential', 'functional')
+ def test_fit_generator_method(self, model_type):
+ if model_type == 'sequential':
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=3, num_classes=4, input_dim=2)
+ else:
+ model = testing_utils.get_small_functional_mlp(
+ num_hidden=3, num_classes=4, input_dim=2)
+ model.compile(
+ loss='mse',
+ optimizer='sgd',
+ metrics=['mae', metrics_module.CategoricalAccuracy()])
- def custom_generator():
- batch_size = 10
- num_samples = 50
- while True:
- batch_index = np.random.randint(0, num_samples - batch_size)
- start = batch_index
- end = start + batch_size
- x = arr_data[start: end]
- y = arr_labels[start: end]
- yield x, y
+ model.fit_generator(custom_generator(),
+ steps_per_epoch=5,
+ epochs=1,
+ verbose=1,
+ max_queue_size=10,
+ workers=4,
+ use_multiprocessing=True)
+ model.fit_generator(custom_generator(),
+ steps_per_epoch=5,
+ epochs=1,
+ verbose=1,
+ max_queue_size=10,
+ use_multiprocessing=False)
+ model.fit_generator(custom_generator(),
+ steps_per_epoch=5,
+ epochs=1,
+ verbose=1,
+ max_queue_size=10,
+ use_multiprocessing=False,
+ validation_data=custom_generator(),
+ validation_steps=10)
+ model.fit_generator(custom_generator(),
+ steps_per_epoch=5,
+ validation_data=custom_generator(),
+ validation_steps=1,
+ workers=0)
- with self.cached_session():
- x = keras.Input((2,))
- y = keras.layers.Dense(1)(x)
- fn_model = keras.models.Model(x, y)
- fn_model.compile(
- loss='mse',
- optimizer='sgd',
- metrics=['mae', metrics_module.CategoricalAccuracy()])
+ @unittest.skipIf(
+ os.name == 'nt',
+ 'use_multiprocessing=True does not work on windows properly.')
+ @parameterized.parameters('sequential', 'functional')
+ def test_evaluate_generator_method(self, model_type):
+ if model_type == 'sequential':
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=3, num_classes=4, input_dim=2)
+ else:
+ model = testing_utils.get_small_functional_mlp(
+ num_hidden=3, num_classes=4, input_dim=2)
+ model.compile(
+ loss='mse',
+ optimizer='sgd',
+ metrics=['mae', metrics_module.CategoricalAccuracy()])
+ model.summary()
- seq_model = keras.models.Sequential()
- seq_model.add(keras.layers.Dense(1, input_shape=(2,)))
- seq_model.compile(loss='mse', optimizer='sgd')
+ model.evaluate_generator(custom_generator(),
+ steps=5,
+ max_queue_size=10,
+ workers=2,
+ verbose=1,
+ use_multiprocessing=True)
+ model.evaluate_generator(custom_generator(),
+ steps=5,
+ max_queue_size=10,
+ use_multiprocessing=False)
+ model.evaluate_generator(custom_generator(),
+ steps=5,
+ max_queue_size=10,
+ use_multiprocessing=False,
+ workers=0)
- for model in [fn_model, seq_model]:
- model.fit_generator(custom_generator(),
- steps_per_epoch=5,
- epochs=1,
- verbose=1,
+ @unittest.skipIf(
+ os.name == 'nt',
+ 'use_multiprocessing=True does not work on windows properly.')
+ @parameterized.parameters('sequential', 'functional')
+ def test_predict_generator_method(self, model_type):
+ if model_type == 'sequential':
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=3, num_classes=4, input_dim=2)
+ else:
+ model = testing_utils.get_small_functional_mlp(
+ num_hidden=3, num_classes=4, input_dim=2)
+ model.compile(
+ loss='mse',
+ optimizer='sgd',
+ metrics=['mae', metrics_module.CategoricalAccuracy()])
+
+ model.predict_generator(custom_generator(),
+ steps=5,
max_queue_size=10,
- workers=4,
+ workers=2,
use_multiprocessing=True)
- model.fit_generator(custom_generator(),
- steps_per_epoch=5,
- epochs=1,
- verbose=1,
+ model.predict_generator(custom_generator(),
+ steps=5,
max_queue_size=10,
use_multiprocessing=False)
- model.fit_generator(custom_generator(),
- steps_per_epoch=5,
- epochs=1,
- verbose=1,
+ model.predict_generator(custom_generator(),
+ steps=5,
max_queue_size=10,
- use_multiprocessing=False,
- validation_data=custom_generator(),
- validation_steps=10)
- model.fit_generator(custom_generator(),
- steps_per_epoch=5,
- validation_data=custom_generator(),
- validation_steps=1,
workers=0)
- model.predict_generator(custom_generator(),
- steps=5,
- max_queue_size=10,
- workers=2,
- use_multiprocessing=True)
- model.predict_generator(custom_generator(),
- steps=5,
- max_queue_size=10,
- use_multiprocessing=False)
- model.predict_generator(custom_generator(),
- steps=5,
- max_queue_size=10,
- workers=0)
- model.evaluate_generator(custom_generator(),
- steps=5,
- max_queue_size=10,
- workers=2,
- verbose=1,
- use_multiprocessing=True)
- model.evaluate_generator(custom_generator(),
- steps=5,
- max_queue_size=10,
- use_multiprocessing=False)
- model.evaluate_generator(custom_generator(),
- steps=5,
- max_queue_size=10,
- use_multiprocessing=False,
- workers=0)
+ # Test generator with just inputs (no targets)
+ model.predict_generator(custom_generator(mode=1),
+ steps=5,
+ max_queue_size=10,
+ workers=2,
+ use_multiprocessing=True)
+ model.predict_generator(custom_generator(mode=1),
+ steps=5,
+ max_queue_size=10,
+ use_multiprocessing=False)
+ model.predict_generator(custom_generator(mode=1),
+ steps=5,
+ max_queue_size=10,
+ workers=0)
def test_generator_methods_with_sample_weights(self):
- arr_data = np.random.random((50, 2))
- arr_labels = np.random.random((50,))
- arr_sample_weights = np.random.random((50,))
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(4, input_shape=(2,)))
+ model.compile(
+ loss='mse',
+ optimizer='sgd',
+ metrics=['mae', metrics_module.CategoricalAccuracy()])
- def custom_generator():
- batch_size = 10
- num_samples = 50
- while True:
- batch_index = np.random.randint(0, num_samples - batch_size)
- start = batch_index
- end = start + batch_size
- x = arr_data[start: end]
- y = arr_labels[start: end]
- w = arr_sample_weights[start: end]
- yield x, y, w
+ model.fit_generator(custom_generator(mode=3),
+ steps_per_epoch=5,
+ epochs=1,
+ verbose=1,
+ max_queue_size=10,
+ use_multiprocessing=False)
+ model.fit_generator(custom_generator(mode=3),
+ steps_per_epoch=5,
+ epochs=1,
+ verbose=1,
+ max_queue_size=10,
+ use_multiprocessing=False,
+ validation_data=custom_generator(mode=3),
+ validation_steps=10)
+ model.predict_generator(custom_generator(mode=3),
+ steps=5,
+ max_queue_size=10,
+ use_multiprocessing=False)
+ model.evaluate_generator(custom_generator(mode=3),
+ steps=5,
+ max_queue_size=10,
+ use_multiprocessing=False)
- with self.cached_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(1, input_shape=(2,)))
- model.compile(
- loss='mse',
- optimizer='sgd',
- metrics=['mae', metrics_module.CategoricalAccuracy()])
+ def test_generator_methods_invalid_use_case(self):
- model.fit_generator(custom_generator(),
+ def invalid_generator():
+ while 1:
+ yield 0
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(4, input_shape=(2,)))
+ model.compile(loss='mse', optimizer='sgd')
+
+ with self.assertRaises(ValueError):
+ model.fit_generator(invalid_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False)
+ with self.assertRaises(ValueError):
model.fit_generator(custom_generator(),
steps_per_epoch=5,
epochs=1,
verbose=1,
max_queue_size=10,
use_multiprocessing=False,
- validation_data=custom_generator(),
+ validation_data=invalid_generator(),
validation_steps=10)
- model.predict_generator(custom_generator(),
+ with self.assertRaises(AttributeError):
+ model.predict_generator(invalid_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
- model.evaluate_generator(custom_generator(),
+ with self.assertRaises(ValueError):
+ model.evaluate_generator(invalid_generator(),
steps=5,
max_queue_size=10,
use_multiprocessing=False)
- def test_generator_methods_invalid_use_case(self):
+ def test_generator_input_to_fit_eval_predict(self):
+ val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
- def custom_generator():
- while 1:
- yield 0
+ def ones_generator():
+ while True:
+ yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
- with self.cached_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(1, input_shape=(2,)))
- model.compile(loss='mse', optimizer='sgd')
+ inputs = keras.layers.Input(shape=(10,))
+ x = keras.layers.Dense(10, activation='relu')(inputs)
+ outputs = keras.layers.Dense(1, activation='sigmoid')(x)
+ model = keras.Model(inputs, outputs)
- with self.assertRaises(ValueError):
- model.fit_generator(custom_generator(),
- steps_per_epoch=5,
- epochs=1,
- verbose=1,
- max_queue_size=10,
- use_multiprocessing=False)
- with self.assertRaises(ValueError):
- model.fit_generator(custom_generator(),
- steps_per_epoch=5,
- epochs=1,
- verbose=1,
- max_queue_size=10,
- use_multiprocessing=False,
- validation_data=custom_generator(),
- validation_steps=10)
- with self.assertRaises(AttributeError):
- model.predict_generator(custom_generator(),
- steps=5,
- max_queue_size=10,
- use_multiprocessing=False)
- with self.assertRaises(ValueError):
- model.evaluate_generator(custom_generator(),
- steps=5,
- max_queue_size=10,
- use_multiprocessing=False)
+ model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
+ model.fit(
+ ones_generator(),
+ steps_per_epoch=2,
+ validation_data=val_data,
+ epochs=2)
+ model.evaluate(ones_generator(), steps=2)
+ model.predict(ones_generator(), steps=2)
+
+
+@tf_test_util.run_all_in_graph_and_eager_modes
+class TestGeneratorMethodsWithSequences(test.TestCase):
def test_training_with_sequences(self):
class DummySequence(keras.utils.Sequence):
def __getitem__(self, idx):
- return np.zeros([10, 2]), np.ones([10])
+ return np.zeros([10, 2]), np.ones([10, 4])
def __len__(self):
return 10
- arr_data = np.random.random((50, 2))
- arr_labels = np.random.random((50,))
- arr_sample_weights = np.random.random((50,))
-
- def custom_generator():
- batch_size = 10
- num_samples = 50
- while True:
- batch_index = np.random.randint(0, num_samples - batch_size)
- start = batch_index
- end = start + batch_size
- x = arr_data[start: end]
- y = arr_labels[start: end]
- w = arr_sample_weights[start: end]
- yield x, y, w
-
- with self.cached_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(1, input_shape=(2,)))
- model.compile(loss='mse', optimizer='sgd')
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(4, input_shape=(2,)))
+ model.compile(loss='mse', optimizer='sgd')
model.fit_generator(DummySequence(),
steps_per_epoch=10,
@@ -251,29 +300,6 @@
workers=0,
use_multiprocessing=False)
- @tf_test_util.run_in_graph_and_eager_modes
- def test_generator_input_to_fit_eval_predict(self):
- val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
-
- def custom_generator():
- while True:
- yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
-
- inputs = keras.layers.Input(shape=(10,))
- x = keras.layers.Dense(10, activation='relu')(inputs)
- outputs = keras.layers.Dense(1, activation='sigmoid')(x)
- model = keras.Model(inputs, outputs)
-
- model.compile(RMSPropOptimizer(0.001), 'binary_crossentropy')
- model.fit(
- custom_generator(),
- steps_per_epoch=2,
- validation_data=val_data,
- epochs=2)
- model.evaluate(custom_generator(), steps=2)
- model.predict(custom_generator(), steps=2)
-
- @tf_test_util.run_in_graph_and_eager_modes
def test_sequence_input_to_fit_eval_predict(self):
val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)
diff --git a/tensorflow/python/keras/engine/training_gpu_test.py b/tensorflow/python/keras/engine/training_gpu_test.py
index 596d085..45dcfe4 100644
--- a/tensorflow/python/keras/engine/training_gpu_test.py
+++ b/tensorflow/python/keras/engine/training_gpu_test.py
@@ -69,7 +69,7 @@
return simple_model
if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
losses_to_test = ['sparse_categorical_crossentropy',
'categorical_crossentropy', 'binary_crossentropy']
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 1009ef7..97dfe6d 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -600,6 +600,34 @@
np.ones((10, 10), 'float32'), np.ones((10, 1), 'float32'), epochs=10)
self.assertTrue('Epoch 5/10' in mock_stdout.getvalue())
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_training_with_loss_instance(self):
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+ loss_weights = [1., 0.5]
+ model.compile(
+ RMSPropOptimizer(learning_rate=0.001),
+ loss=keras.losses.MeanSquaredError(),
+ metrics=[metrics_module.CategoricalAccuracy(), 'mae'],
+ loss_weights=loss_weights)
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+
+ model.fit([input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ batch_size=5)
+
class TestExceptionsAndWarnings(test.TestCase):
@@ -1918,7 +1946,7 @@
w = np.array([[3., 4.], [1., 2.]])
outs = model.evaluate(x, y, sample_weight=w)
- self.assertArrayNear(outs, [0.3, 0.7, 0.3], .001)
+ self.assertArrayNear(outs, [0.75, 0.7, 0.3], .001)
# Verify that metric value is same with arbitrary weights and batch size.
x = np.random.random((50, 2, 1))
@@ -1988,7 +2016,7 @@
# verify that masking is combined with sample weights.
w = np.array([3, 2, 4])
scores = model.train_on_batch(x, y, sample_weight=w)
- self.assertArrayNear(scores, [0.2, 0.8], 0.1)
+ self.assertArrayNear(scores, [0.3328, 0.8], 0.001)
def test_add_metric_with_tensor_on_model_in_graph_mode(self):
with self.cached_session():
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 8669daf..347582a 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -35,6 +35,7 @@
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import weights_broadcast_ops
@@ -58,10 +59,10 @@
def _nested_all(data, cond_func):
"""Checks if all elements in a nested structure satisfy cond_func."""
if isinstance(data, (tuple, list)):
- return all([_nested_all(nested_data, cond_func) for nested_data in data])
+ return all(_nested_all(nested_data, cond_func) for nested_data in data)
elif isinstance(data, dict):
return all(
- [_nested_all(nested_data, cond_func) for nested_data in data.values()])
+ _nested_all(nested_data, cond_func) for nested_data in data.values())
else:
return cond_func(data)
@@ -69,7 +70,7 @@
def _nested_any(data, cond_func):
"""Checks if any nested_elements in a nested structure satisfy cond_func."""
if isinstance(data, (tuple, list)):
- return any([_nested_any(nested_data, cond_func) for nested_data in data])
+ return any(_nested_any(nested_data, cond_func) for nested_data in data)
elif isinstance(data, dict):
return any(
[_nested_any(nested_data, cond_func) for nested_data in data.values()])
@@ -632,15 +633,14 @@
weights = mask
else:
# Update dimensions of weights to match with mask if possible.
- mask, _, weights = metrics_module.squeeze_or_expand_dimensions(
- mask, None, weights)
+ mask, _, weights = squeeze_or_expand_dimensions(mask, None, weights)
weights *= mask
# Apply sample weighting.
if weights is not None:
# Update dimensions of weights to match with values if possible.
- score_array, _, weights = metrics_module.squeeze_or_expand_dimensions(
+ score_array, _, weights = squeeze_or_expand_dimensions(
score_array, None, weights)
try:
# Broadcast weights if possible.
@@ -838,12 +838,22 @@
return metric_fn(y_true, y_pred, sample_weight=mask)
# Update dimensions of weights to match with mask.
- mask, _, weights = metrics_module.squeeze_or_expand_dimensions(
- mask, None, weights)
+ mask, _, weights = squeeze_or_expand_dimensions(mask, None, weights)
weights *= mask
return metric_fn(y_true, y_pred, sample_weight=weights)
+def get_loss_function(loss):
+ """Returns the loss function corresponding to the given loss input."""
+ if loss is None or isinstance(loss, losses.Loss):
+ return loss
+
+ # TODO(psv): After we have added all V2 losses, update this function.
+ if loss in ['mse', 'MSE', 'mean_squared_error']:
+ return losses.MeanSquaredError()
+ return losses.get(loss)
+
+
def validate_iterator_input(x, y, sample_weight, validation_split=None):
"""Validates user input arguments when a dataset iterator is passed.
diff --git a/tensorflow/python/keras/integration_test.py b/tensorflow/python/keras/integration_test.py
index 25ca9e6..3c0f73b 100644
--- a/tensorflow/python/keras/integration_test.py
+++ b/tensorflow/python/keras/integration_test.py
@@ -26,7 +26,6 @@
from tensorflow.python.layers import core as tf_core_layers
from tensorflow.python.ops import nn
from tensorflow.python.ops import rnn_cell
-from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
@@ -313,15 +312,6 @@
verbose=0)
self.assertGreater(history.history['val_acc'][-1], 0.7)
- def test_regularizers_with_get_variable(self):
- # Test case for GitHub issue 22470.
- with self.cached_session():
- v = variable_scope.get_variable(
- 'v',
- shape=[4, 4],
- initializer=keras.initializers.glorot_uniform(),
- regularizer=keras.regularizers.l2(0.))
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
index 7268040..49990b6b 100644
--- a/tensorflow/python/keras/layers/__init__.py
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -22,7 +22,7 @@
# pylint: disable=g-bad-import-order
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
-from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
# Advanced activations.
diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py
index a2385df..35ac783 100644
--- a/tensorflow/python/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/layers/advanced_activations.py
@@ -22,8 +22,8 @@
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
+from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index d1b03b8..6564d6e 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -26,8 +26,8 @@
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
+from tensorflow.python.keras.engine.input_spec import InputSpec
# imports for backwards namespace compatibility
# pylint: disable=unused-import
from tensorflow.python.keras.layers.pooling import AveragePooling1D
diff --git a/tensorflow/python/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/layers/convolutional_recurrent.py
index 1005421..cf3861d 100644
--- a/tensorflow/python/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/python/keras/layers/convolutional_recurrent.py
@@ -26,8 +26,8 @@
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
+from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers.recurrent import _generate_dropout_mask
from tensorflow.python.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras.layers.recurrent import RNN
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 8031272..56dd705 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -34,8 +34,8 @@
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
+from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py
index beacdf2..81f2928 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent.py
@@ -25,7 +25,7 @@
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine.base_layer import InputSpec
+from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers.recurrent import RNN
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_cudnn_rnn_ops
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
index cc93364..1f195f3 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent_test.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
@@ -31,64 +31,76 @@
from tensorflow.python.training.rmsprop import RMSPropOptimizer
+@test_util.run_all_in_graph_and_eager_modes
class CuDNNTest(test.TestCase, parameterized.TestCase):
- @test_util.run_in_graph_and_eager_modes
def test_cudnn_rnn_basics(self):
- if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
- input_size = 10
- timesteps = 6
- units = 2
- num_samples = 32
- for layer_class in [keras.layers.CuDNNGRU, keras.layers.CuDNNLSTM]:
- for return_sequences in [True, False]:
- with keras.utils.CustomObjectScope(
- {'keras.layers.CuDNNGRU': keras.layers.CuDNNGRU,
- 'keras.layers.CuDNNLSTM': keras.layers.CuDNNLSTM}):
- testing_utils.layer_test(
- layer_class,
- kwargs={'units': units,
- 'return_sequences': return_sequences},
- input_shape=(num_samples, timesteps, input_size))
- for go_backwards in [True, False]:
- with keras.utils.CustomObjectScope(
- {'keras.layers.CuDNNGRU': keras.layers.CuDNNGRU,
- 'keras.layers.CuDNNLSTM': keras.layers.CuDNNLSTM}):
- testing_utils.layer_test(
- layer_class,
- kwargs={'units': units,
- 'go_backwards': go_backwards},
- input_shape=(num_samples, timesteps, input_size))
+ if not test.is_gpu_available(cuda_only=True):
+ self.skipTest('No CUDA GPU available')
- @test_util.run_in_graph_and_eager_modes
+ with test_util.use_gpu():
+ input_size = 10
+ timesteps = 6
+ units = 2
+ num_samples = 32
+ for layer_class in [keras.layers.CuDNNGRU, keras.layers.CuDNNLSTM]:
+ for return_sequences in [True, False]:
+ with keras.utils.CustomObjectScope({
+ 'keras.layers.CuDNNGRU': keras.layers.CuDNNGRU,
+ 'keras.layers.CuDNNLSTM': keras.layers.CuDNNLSTM
+ }):
+ testing_utils.layer_test(
+ layer_class,
+ kwargs={
+ 'units': units,
+ 'return_sequences': return_sequences
+ },
+ input_shape=(num_samples, timesteps, input_size))
+ for go_backwards in [True, False]:
+ with keras.utils.CustomObjectScope({
+ 'keras.layers.CuDNNGRU': keras.layers.CuDNNGRU,
+ 'keras.layers.CuDNNLSTM': keras.layers.CuDNNLSTM
+ }):
+ testing_utils.layer_test(
+ layer_class,
+ kwargs={
+ 'units': units,
+ 'go_backwards': go_backwards
+ },
+ input_shape=(num_samples, timesteps, input_size))
+
def test_trainability(self):
- if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
- input_size = 10
- units = 2
- for layer_class in [keras.layers.CuDNNGRU, keras.layers.CuDNNLSTM]:
- layer = layer_class(units)
- layer.build((None, None, input_size))
- self.assertEqual(len(layer.weights), 3)
- self.assertEqual(len(layer.trainable_weights), 3)
- self.assertEqual(len(layer.non_trainable_weights), 0)
- layer.trainable = False
- self.assertEqual(len(layer.weights), 3)
- self.assertEqual(len(layer.non_trainable_weights), 3)
- self.assertEqual(len(layer.trainable_weights), 0)
- layer.trainable = True
- self.assertEqual(len(layer.weights), 3)
- self.assertEqual(len(layer.trainable_weights), 3)
- self.assertEqual(len(layer.non_trainable_weights), 0)
+ if not test.is_gpu_available(cuda_only=True):
+ self.skipTest('No CUDA GPU available')
+
+ with test_util.use_gpu():
+ input_size = 10
+ units = 2
+ for layer_class in [keras.layers.CuDNNGRU, keras.layers.CuDNNLSTM]:
+ layer = layer_class(units)
+ layer.build((None, None, input_size))
+ self.assertEqual(len(layer.weights), 3)
+ self.assertEqual(len(layer.trainable_weights), 3)
+ self.assertEqual(len(layer.non_trainable_weights), 0)
+ layer.trainable = False
+ self.assertEqual(len(layer.weights), 3)
+ self.assertEqual(len(layer.non_trainable_weights), 3)
+ self.assertEqual(len(layer.trainable_weights), 0)
+ layer.trainable = True
+ self.assertEqual(len(layer.weights), 3)
+ self.assertEqual(len(layer.trainable_weights), 3)
+ self.assertEqual(len(layer.non_trainable_weights), 0)
@parameterized.named_parameters(
('cudnngru', keras.layers.CuDNNGRU),
('cudnnlstm', keras.layers.CuDNNLSTM),
)
def test_regularizer(self, layer_class):
+ if not test.is_gpu_available(cuda_only=True):
+ self.skipTest('No CUDA GPU available')
+
if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
input_size = 10
timesteps = 6
units = 2
@@ -119,132 +131,140 @@
('cudnnlstm', keras.layers.CuDNNLSTM),
)
def test_return_state(self, layer_class):
- if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
- input_size = 10
- timesteps = 6
- units = 2
- num_samples = 32
- num_states = 2 if layer_class is keras.layers.CuDNNLSTM else 1
+ if not test.is_gpu_available(cuda_only=True):
+ self.skipTest('No CUDA GPU available')
- inputs = keras.Input(batch_shape=(num_samples, timesteps, input_size))
- layer = layer_class(units, return_state=True, stateful=True)
- outputs = layer(inputs)
- _, state = outputs[0], outputs[1:]
- self.assertEqual(len(state), num_states)
- model = keras.models.Model(inputs, state[0])
+ with test_util.use_gpu():
+ input_size = 10
+ timesteps = 6
+ units = 2
+ num_samples = 32
+ num_states = 2 if layer_class is keras.layers.CuDNNLSTM else 1
- inputs = np.random.random((num_samples, timesteps, input_size))
- state = model.predict(inputs)
- np.testing.assert_allclose(
- keras.backend.eval(layer.states[0]), state, atol=1e-4)
+ inputs = keras.Input(batch_shape=(num_samples, timesteps, input_size))
+ layer = layer_class(units, return_state=True, stateful=True)
+ outputs = layer(inputs)
+ _, state = outputs[0], outputs[1:]
+ self.assertEqual(len(state), num_states)
+ model = keras.models.Model(inputs, state[0])
+
+ inputs = np.random.random((num_samples, timesteps, input_size))
+ state = model.predict(inputs)
+ np.testing.assert_allclose(
+ keras.backend.eval(layer.states[0]), state, atol=1e-4)
@parameterized.named_parameters(
('cudnngru', keras.layers.CuDNNGRU),
('cudnnlstm', keras.layers.CuDNNLSTM),
)
def test_time_major_input(self, layer_class):
- if test.is_gpu_available(cuda_only=True):
- with self.test_session(use_gpu=True):
- input_size = 10
- timesteps = 6
- units = 2
- num_samples = 32
+ if not test.is_gpu_available(cuda_only=True):
+ self.skipTest('No CUDA GPU available')
- model = keras.models.Sequential()
- model.add(
- keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2])))
- layer = layer_class(units, time_major=True, return_sequences=True)
- model.add(layer)
- model.add(
- keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2])))
- model.compile(loss='categorical_crossentropy', optimizer='adam')
- model.fit(
- np.ones((num_samples, timesteps, input_size)),
- np.ones((num_samples, timesteps, units)))
- out = model.predict(np.ones((num_samples, timesteps, input_size)))
- self.assertEqual(out.shape, (num_samples, timesteps, units))
+ with test_util.use_gpu():
+ input_size = 10
+ timesteps = 6
+ units = 2
+ num_samples = 32
+
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2])))
+ layer = layer_class(units, time_major=True, return_sequences=True)
+ model.add(layer)
+ model.add(
+ keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2])))
+ model.compile(loss='categorical_crossentropy', optimizer='adam')
+ model.fit(
+ np.ones((num_samples, timesteps, input_size)),
+ np.ones((num_samples, timesteps, units)))
+ out = model.predict(np.ones((num_samples, timesteps, input_size)))
+ self.assertEqual(out.shape, (num_samples, timesteps, units))
@parameterized.named_parameters(
('cudnngru', keras.layers.CuDNNGRU),
('cudnnlstm', keras.layers.CuDNNLSTM),
)
def test_specify_initial_state_keras_tensor(self, layer_class):
- if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
- input_size = 10
- timesteps = 6
- units = 2
- num_samples = 32
- num_states = 2 if layer_class is keras.layers.CuDNNLSTM else 1
+ if not test.is_gpu_available(cuda_only=True):
+ self.skipTest('No CUDA GPU available')
- inputs = keras.Input((timesteps, input_size))
- initial_state = [keras.Input((units,)) for _ in range(num_states)]
- layer = layer_class(units)
- if len(initial_state) == 1:
- output = layer(inputs, initial_state=initial_state[0])
- else:
- output = layer(inputs, initial_state=initial_state)
- self.assertIn(initial_state[0], layer._inbound_nodes[0].input_tensors)
+ with test_util.use_gpu():
+ input_size = 10
+ timesteps = 6
+ units = 2
+ num_samples = 32
+ num_states = 2 if layer_class is keras.layers.CuDNNLSTM else 1
- model = keras.models.Model([inputs] + initial_state, output)
- model.compile(loss='categorical_crossentropy', optimizer='adam')
+ inputs = keras.Input((timesteps, input_size))
+ initial_state = [keras.Input((units,)) for _ in range(num_states)]
+ layer = layer_class(units)
+ if len(initial_state) == 1:
+ output = layer(inputs, initial_state=initial_state[0])
+ else:
+ output = layer(inputs, initial_state=initial_state)
+ self.assertIn(initial_state[0], layer._inbound_nodes[0].input_tensors)
- inputs = np.random.random((num_samples, timesteps, input_size))
- initial_state = [
- np.random.random((num_samples, units)) for _ in range(num_states)
- ]
- targets = np.random.random((num_samples, units))
- model.fit([inputs] + initial_state, targets)
+ model = keras.models.Model([inputs] + initial_state, output)
+ model.compile(loss='categorical_crossentropy', optimizer='adam')
+
+ inputs = np.random.random((num_samples, timesteps, input_size))
+ initial_state = [
+ np.random.random((num_samples, units)) for _ in range(num_states)
+ ]
+ targets = np.random.random((num_samples, units))
+ model.fit([inputs] + initial_state, targets)
@parameterized.named_parameters(
('cudnngru', keras.layers.CuDNNGRU),
('cudnnlstm', keras.layers.CuDNNLSTM),
)
def test_statefulness(self, layer_class):
- if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
- input_size = 10
- timesteps = 6
- units = 2
- num_samples = 32
+ if not test.is_gpu_available(cuda_only=True):
+ self.skipTest('No CUDA GPU available')
- model = keras.models.Sequential()
- model.add(
- keras.layers.Embedding(
- 10,
- input_size,
- input_length=timesteps,
- batch_input_shape=(num_samples, timesteps)))
- layer = layer_class(
- units, return_sequences=False, stateful=True, weights=None)
- model.add(layer)
- model.compile(optimizer='sgd', loss='mse')
- out1 = model.predict(np.ones((num_samples, timesteps)))
- self.assertEqual(out1.shape, (num_samples, units))
+ with test_util.use_gpu():
+ input_size = 10
+ timesteps = 6
+ units = 2
+ num_samples = 32
- # train once so that the states change
- model.train_on_batch(
- np.ones((num_samples, timesteps)), np.ones((num_samples, units)))
- out2 = model.predict(np.ones((num_samples, timesteps)))
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.Embedding(
+ 10,
+ input_size,
+ input_length=timesteps,
+ batch_input_shape=(num_samples, timesteps)))
+ layer = layer_class(
+ units, return_sequences=False, stateful=True, weights=None)
+ model.add(layer)
+ model.compile(optimizer='sgd', loss='mse')
+ out1 = model.predict(np.ones((num_samples, timesteps)))
+ self.assertEqual(out1.shape, (num_samples, units))
- # if the state is not reset, output should be different
- self.assertNotEqual(out1.max(), out2.max())
+ # train once so that the states change
+ model.train_on_batch(
+ np.ones((num_samples, timesteps)), np.ones((num_samples, units)))
+ out2 = model.predict(np.ones((num_samples, timesteps)))
- # check that output changes after states are reset
- # (even though the model itself didn't change)
- layer.reset_states()
- out3 = model.predict(np.ones((num_samples, timesteps)))
- self.assertNotEqual(out2.max(), out3.max())
+ # if the state is not reset, output should be different
+ self.assertNotEqual(out1.max(), out2.max())
- # check that container-level reset_states() works
- model.reset_states()
- out4 = model.predict(np.ones((num_samples, timesteps)))
- self.assertAllClose(out3, out4, atol=1e-5)
+ # check that output changes after states are reset
+ # (even though the model itself didn't change)
+ layer.reset_states()
+ out3 = model.predict(np.ones((num_samples, timesteps)))
+ self.assertNotEqual(out2.max(), out3.max())
- # check that the call to `predict` updated the states
- out5 = model.predict(np.ones((num_samples, timesteps)))
- self.assertNotEqual(out4.max(), out5.max())
+ # check that container-level reset_states() works
+ model.reset_states()
+ out4 = model.predict(np.ones((num_samples, timesteps)))
+ self.assertAllClose(out3, out4, atol=1e-5)
+
+ # check that the call to `predict` updated the states
+ out5 = model.predict(np.ones((num_samples, timesteps)))
+ self.assertNotEqual(out4.max(), out5.max())
@parameterized.named_parameters(
*test_util.generate_combinations_with_testcase_name(
@@ -254,49 +274,51 @@
def test_load_weights_between_noncudnn_rnn(self, rnn_type, to_cudnn,
bidirectional, implementation,
model_nest_level, model_type):
- if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
- input_size = 10
- timesteps = 6
- input_shape = (timesteps, input_size)
- units = 2
- num_samples = 32
- inputs = np.random.random((num_samples, timesteps, input_size))
+ if not test.is_gpu_available(cuda_only=True):
+ self.skipTest('No CUDA GPU available')
- rnn_layer_kwargs = {
- 'recurrent_activation': 'sigmoid',
- # ensure biases are non-zero and properly converted
- 'bias_initializer': 'random_uniform',
- 'implementation': implementation
- }
- if rnn_type == 'LSTM':
- rnn_layer_class = keras.layers.LSTM
- cudnn_rnn_layer_class = keras.layers.CuDNNLSTM
- else:
- rnn_layer_class = keras.layers.GRU
- cudnn_rnn_layer_class = keras.layers.CuDNNGRU
- rnn_layer_kwargs['reset_after'] = True
+ with test_util.use_gpu():
+ input_size = 10
+ timesteps = 6
+ input_shape = (timesteps, input_size)
+ units = 2
+ num_samples = 32
+ inputs = np.random.random((num_samples, timesteps, input_size))
- layer = rnn_layer_class(units, **rnn_layer_kwargs)
- if bidirectional:
- layer = keras.layers.Bidirectional(layer)
+ rnn_layer_kwargs = {
+ 'recurrent_activation': 'sigmoid',
+ # ensure biases are non-zero and properly converted
+ 'bias_initializer': 'random_uniform',
+ 'implementation': implementation
+ }
+ if rnn_type == 'LSTM':
+ rnn_layer_class = keras.layers.LSTM
+ cudnn_rnn_layer_class = keras.layers.CuDNNLSTM
+ else:
+ rnn_layer_class = keras.layers.GRU
+ cudnn_rnn_layer_class = keras.layers.CuDNNGRU
+ rnn_layer_kwargs['reset_after'] = True
- cudnn_layer = cudnn_rnn_layer_class(units)
- if bidirectional:
- cudnn_layer = keras.layers.Bidirectional(cudnn_layer)
+ layer = rnn_layer_class(units, **rnn_layer_kwargs)
+ if bidirectional:
+ layer = keras.layers.Bidirectional(layer)
- model = self._make_nested_model(input_shape, layer, model_nest_level,
- model_type)
- cudnn_model = self._make_nested_model(input_shape, cudnn_layer,
- model_nest_level, model_type)
+ cudnn_layer = cudnn_rnn_layer_class(units)
+ if bidirectional:
+ cudnn_layer = keras.layers.Bidirectional(cudnn_layer)
- if to_cudnn:
- self._convert_model_weights(model, cudnn_model)
- else:
- self._convert_model_weights(cudnn_model, model)
+ model = self._make_nested_model(input_shape, layer, model_nest_level,
+ model_type)
+ cudnn_model = self._make_nested_model(input_shape, cudnn_layer,
+ model_nest_level, model_type)
- self.assertAllClose(model.predict(inputs), cudnn_model.predict(inputs),
- atol=1e-4)
+ if to_cudnn:
+ self._convert_model_weights(model, cudnn_model)
+ else:
+ self._convert_model_weights(cudnn_model, model)
+
+ self.assertAllClose(
+ model.predict(inputs), cudnn_model.predict(inputs), atol=1e-4)
def _make_nested_model(self, input_shape, layer, level=1, model_type='func'):
# example: make_nested_seq_model((1,), Dense(10), level=2).summary()
@@ -334,149 +356,145 @@
to_cudnn):
# Similar test as test_load_weights_between_noncudnn_rnn() but has different
# rank of input due to usage of TimeDistributed. Issue: #10356.
- if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
- input_size = 10
- steps = 6
- timesteps = 6
- input_shape = (timesteps, steps, input_size)
- units = 2
- num_samples = 32
- inputs = np.random.random((num_samples, timesteps, steps, input_size))
+ if not test.is_gpu_available(cuda_only=True):
+ self.skipTest('No CUDA GPU available')
- rnn_layer_kwargs = {
- 'recurrent_activation': 'sigmoid',
- # ensure biases are non-zero and properly converted
- 'bias_initializer': 'random_uniform',
- }
- if rnn_type == 'LSTM':
- rnn_layer_class = keras.layers.LSTM
- cudnn_rnn_layer_class = keras.layers.CuDNNLSTM
- else:
- rnn_layer_class = keras.layers.GRU
- cudnn_rnn_layer_class = keras.layers.CuDNNGRU
- rnn_layer_kwargs['reset_after'] = True
+ with test_util.use_gpu():
+ input_size = 10
+ steps = 6
+ timesteps = 6
+ input_shape = (timesteps, steps, input_size)
+ units = 2
+ num_samples = 32
+ inputs = np.random.random((num_samples, timesteps, steps, input_size))
- layer = rnn_layer_class(units, **rnn_layer_kwargs)
- layer = keras.layers.TimeDistributed(layer)
+ rnn_layer_kwargs = {
+ 'recurrent_activation': 'sigmoid',
+ # ensure biases are non-zero and properly converted
+ 'bias_initializer': 'random_uniform',
+ }
+ if rnn_type == 'LSTM':
+ rnn_layer_class = keras.layers.LSTM
+ cudnn_rnn_layer_class = keras.layers.CuDNNLSTM
+ else:
+ rnn_layer_class = keras.layers.GRU
+ cudnn_rnn_layer_class = keras.layers.CuDNNGRU
+ rnn_layer_kwargs['reset_after'] = True
- cudnn_layer = cudnn_rnn_layer_class(units)
- cudnn_layer = keras.layers.TimeDistributed(cudnn_layer)
+ layer = rnn_layer_class(units, **rnn_layer_kwargs)
+ layer = keras.layers.TimeDistributed(layer)
- model = self._make_nested_model(input_shape, layer)
- cudnn_model = self._make_nested_model(input_shape, cudnn_layer)
+ cudnn_layer = cudnn_rnn_layer_class(units)
+ cudnn_layer = keras.layers.TimeDistributed(cudnn_layer)
- if to_cudnn:
- self._convert_model_weights(model, cudnn_model)
- else:
- self._convert_model_weights(cudnn_model, model)
+ model = self._make_nested_model(input_shape, layer)
+ cudnn_model = self._make_nested_model(input_shape, cudnn_layer)
- self.assertAllClose(model.predict(inputs), cudnn_model.predict(inputs),
- atol=1e-4)
+ if to_cudnn:
+ self._convert_model_weights(model, cudnn_model)
+ else:
+ self._convert_model_weights(cudnn_model, model)
- @test_util.run_in_graph_and_eager_modes
+ self.assertAllClose(
+ model.predict(inputs), cudnn_model.predict(inputs), atol=1e-4)
+
def test_cudnnrnn_bidirectional(self):
- if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
- rnn = keras.layers.CuDNNGRU
- samples = 2
- dim = 2
- timesteps = 2
- output_dim = 2
- mode = 'concat'
+ if not test.is_gpu_available(cuda_only=True):
+ self.skipTest('No CUDA GPU available')
- x = np.random.random((samples, timesteps, dim))
- target_dim = 2 * output_dim if mode == 'concat' else output_dim
- y = np.random.random((samples, target_dim))
+ with test_util.use_gpu():
+ rnn = keras.layers.CuDNNGRU
+ samples = 2
+ dim = 2
+ timesteps = 2
+ output_dim = 2
+ mode = 'concat'
- # test with Sequential model
- model = keras.Sequential()
- model.add(
- keras.layers.Bidirectional(
- rnn(output_dim), merge_mode=mode, input_shape=(None, dim)))
- model.compile(
- loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
- model.fit(x, y, epochs=1, batch_size=1)
+ x = np.random.random((samples, timesteps, dim))
+ target_dim = 2 * output_dim if mode == 'concat' else output_dim
+ y = np.random.random((samples, target_dim))
- # test config
- model.get_config()
- model = keras.models.model_from_json(model.to_json())
- model.summary()
+ # test with Sequential model
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Bidirectional(
+ rnn(output_dim), merge_mode=mode, input_shape=(None, dim)))
+ model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
+ model.fit(x, y, epochs=1, batch_size=1)
- # test stacked bidirectional layers
- model = keras.Sequential()
- model.add(
- keras.layers.Bidirectional(
- rnn(output_dim, return_sequences=True),
- merge_mode=mode,
- input_shape=(None, dim)))
- model.add(keras.layers.Bidirectional(rnn(output_dim), merge_mode=mode))
- model.compile(
- loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
- model.fit(x, y, epochs=1, batch_size=1)
+ # test config
+ model.get_config()
+ model = keras.models.model_from_json(model.to_json())
+ model.summary()
- # test with functional API
- inputs = keras.Input((timesteps, dim))
- outputs = keras.layers.Bidirectional(
- rnn(output_dim), merge_mode=mode)(
- inputs)
- model = keras.Model(inputs, outputs)
- model.compile(
- loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
- model.fit(x, y, epochs=1, batch_size=1)
+ # test stacked bidirectional layers
+ model = keras.Sequential()
+ model.add(
+ keras.layers.Bidirectional(
+ rnn(output_dim, return_sequences=True),
+ merge_mode=mode,
+ input_shape=(None, dim)))
+ model.add(keras.layers.Bidirectional(rnn(output_dim), merge_mode=mode))
+ model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
+ model.fit(x, y, epochs=1, batch_size=1)
- # Bidirectional and stateful
- inputs = keras.Input(batch_shape=(1, timesteps, dim))
- outputs = keras.layers.Bidirectional(
- rnn(output_dim, stateful=True), merge_mode=mode)(
- inputs)
- model = keras.Model(inputs, outputs)
- model.compile(
- loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
- model.fit(x, y, epochs=1, batch_size=1)
+ # test with functional API
+ inputs = keras.Input((timesteps, dim))
+ outputs = keras.layers.Bidirectional(
+ rnn(output_dim), merge_mode=mode)(
+ inputs)
+ model = keras.Model(inputs, outputs)
+ model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
+ model.fit(x, y, epochs=1, batch_size=1)
+
+ # Bidirectional and stateful
+ inputs = keras.Input(batch_shape=(1, timesteps, dim))
+ outputs = keras.layers.Bidirectional(
+ rnn(output_dim, stateful=True), merge_mode=mode)(
+ inputs)
+ model = keras.Model(inputs, outputs)
+ model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001))
+ model.fit(x, y, epochs=1, batch_size=1)
def test_preprocess_weights_for_loading_gru_incompatible(self):
"""Test loading weights between incompatible layers.
Should fail fast with an exception.
"""
- if test.is_gpu_available(cuda_only=True):
- with self.session(use_gpu=True):
- input_shape = (3, 5)
+ if not test.is_gpu_available(cuda_only=True):
+ self.skipTest('No CUDA GPU available')
- def gru(cudnn=False, **kwargs):
- layer_class = keras.layers.CuDNNGRU if cudnn else keras.layers.GRU
- return layer_class(2, input_shape=input_shape, **kwargs)
+ with test_util.use_gpu():
+ input_shape = (3, 5)
- def get_layer_weights(layer):
- layer.build(input_shape=input_shape)
- return layer.get_weights()
+ def gru(cudnn=False, **kwargs):
+ layer_class = keras.layers.CuDNNGRU if cudnn else keras.layers.GRU
+ return layer_class(2, input_shape=input_shape, **kwargs)
- def assert_not_compatible(src, dest, message):
- with self.assertRaises(ValueError) as ex:
- keras.engine.saving.preprocess_weights_for_loading(
- dest,
- get_layer_weights(src))
- self.assertIn(message, str(ex.exception))
+ def get_layer_weights(layer):
+ layer.build(input_shape=input_shape)
+ return layer.get_weights()
- assert_not_compatible(
- gru(),
- gru(cudnn=True),
- 'GRU(reset_after=False) is not compatible with CuDNNGRU')
- assert_not_compatible(
- gru(cudnn=True),
- gru(),
- 'CuDNNGRU is not compatible with GRU(reset_after=False)')
- assert_not_compatible(
- gru(),
- gru(reset_after=True),
- 'GRU(reset_after=False) is not compatible with '
- 'GRU(reset_after=True)')
- assert_not_compatible(
- gru(reset_after=True),
- gru(),
- 'GRU(reset_after=True) is not compatible with '
- 'GRU(reset_after=False)')
+ def assert_not_compatible(src, dest, message):
+ with self.assertRaises(ValueError) as ex:
+ keras.engine.saving.preprocess_weights_for_loading(
+ dest, get_layer_weights(src))
+ self.assertIn(message, str(ex.exception))
+
+ assert_not_compatible(
+ gru(), gru(cudnn=True),
+ 'GRU(reset_after=False) is not compatible with CuDNNGRU')
+ assert_not_compatible(
+ gru(cudnn=True), gru(),
+ 'CuDNNGRU is not compatible with GRU(reset_after=False)')
+ assert_not_compatible(
+ gru(), gru(reset_after=True),
+ 'GRU(reset_after=False) is not compatible with '
+ 'GRU(reset_after=True)')
+ assert_not_compatible(
+ gru(reset_after=True), gru(),
+ 'GRU(reset_after=True) is not compatible with '
+ 'GRU(reset_after=False)')
if __name__ == '__main__':
diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py
index 28d8ef2..e8a8575 100644
--- a/tensorflow/python/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/layers/embeddings.py
@@ -45,11 +45,11 @@
model = Sequential()
model.add(Embedding(1000, 64, input_length=10))
# the model will take as input an integer matrix of size (batch,
- input_length).
+ # input_length).
# the largest integer (i.e. word index) in the input should be no larger
- than 999 (vocabulary size).
+ # than 999 (vocabulary size).
# now model.output_shape == (None, 10, 64), where None is the batch
- dimension.
+ # dimension.
input_array = np.random.randint(1000, size=(32, 10))
diff --git a/tensorflow/python/keras/layers/local.py b/tensorflow/python/keras/layers/local.py
index 33d09a1..d2c4aaa 100644
--- a/tensorflow/python/keras/layers/local.py
+++ b/tensorflow/python/keras/layers/local.py
@@ -23,8 +23,8 @@
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
+from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.util.tf_export import tf_export
diff --git a/tensorflow/python/keras/layers/merge.py b/tensorflow/python/keras/layers/merge.py
index f295af3..45e705c 100644
--- a/tensorflow/python/keras/layers/merge.py
+++ b/tensorflow/python/keras/layers/merge.py
@@ -212,7 +212,7 @@
if len(mask) != len(inputs):
raise ValueError('The lists `inputs` and `mask` '
'should have the same length.')
- if all([m is None for m in mask]):
+ if all(m is None for m in mask):
return None
masks = [array_ops.expand_dims(m, axis=0) for m in mask if m is not None]
return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)
@@ -378,7 +378,7 @@
if not isinstance(input_shape, list) or len(input_shape) < 2:
raise ValueError('A `Concatenate` layer should be called '
'on a list of at least 2 inputs')
- if all([shape is None for shape in input_shape]):
+ if all(shape is None for shape in input_shape):
return
reduced_inputs_shapes = [list(shape) for shape in input_shape]
shape_set = set()
@@ -418,7 +418,7 @@
if len(mask) != len(inputs):
raise ValueError('The lists `inputs` and `mask` '
'should have the same length.')
- if all([m is None for m in mask]):
+ if all(m is None for m in mask):
return None
# Make a list of masks while making sure
# the dimensionality of each mask
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 7a91693..d958497 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function
+from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -26,8 +27,8 @@
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
+from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@@ -40,8 +41,8 @@
from tensorflow.python.util.tf_export import tf_export
-@tf_export('keras.layers.BatchNormalization')
-class BatchNormalization(Layer):
+@tf_export('keras.layers.BatchNormalization', v1=[])
+class BatchNormalizationV2(Layer):
"""Batch normalization layer (Ioffe and Szegedy, 2014).
Normalize the activations of the previous layer at each batch,
@@ -84,8 +85,10 @@
and should be neither too small (which would add noise) nor too large
(which would give stale estimates). Note that `momentum` is still applied
to get the means and variances for inference.
- fused: if `None` or `True`, use a faster, fused implementation if possible.
- If `False`, use the system recommended implementation.
+ fused: if `True`, use a faster, fused implementation, or raise a ValueError
+ if the fused implementation cannot be used. If `None`, use the faster
+ implementation if possible. If False, do not used the fused
+ implementation.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`,
@@ -120,6 +123,9 @@
Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
"""
+ # The BatchNormalizationV1 subclass sets this to False to use the V1 behavior.
+ _USE_V2_BEHAVIOR = True
+
def __init__(self,
axis=-1,
momentum=0.99,
@@ -143,12 +149,15 @@
adjustment=None,
name=None,
**kwargs):
- super(BatchNormalization, self).__init__(
+ super(BatchNormalizationV2, self).__init__(
name=name, trainable=trainable, **kwargs)
if isinstance(axis, list):
self.axis = axis[:]
- else:
+ elif isinstance(axis, int):
self.axis = axis
+ else:
+ raise TypeError('axis must be int or list, type given: %s'
+ % type(self.axis))
self.momentum = momentum
self.epsilon = epsilon
self.center = center
@@ -165,7 +174,14 @@
self.renorm = renorm
self.virtual_batch_size = virtual_batch_size
self.adjustment = adjustment
- if fused is None:
+ if self._USE_V2_BEHAVIOR:
+ if fused:
+ self._raise_if_fused_cannot_be_used()
+ # We leave fused as None if self._fused_can_be_used()==True, since we
+ # still may set it to False in self.build() if the input rank is not 4.
+ elif fused is None and not self._fused_can_be_used():
+ fused = False
+ elif fused is None:
fused = True
self.supports_masking = True
@@ -181,6 +197,38 @@
self.renorm_clipping = renorm_clipping
self.renorm_momentum = renorm_momentum
+ def _raise_if_fused_cannot_be_used(self):
+ """Raises a ValueError if fused implementation cannot be used.
+
+ In addition to the checks done in this function, the input tensors rank must
+ be 4. The input rank check can only be done once the input shape is known.
+ """
+ # Currently fused batch norm doesn't support renorm. It also only supports a
+ # channel dimension on axis 1 or 3, when no virtual batch size or adjustment
+ # is used.
+ if self.renorm:
+ raise ValueError('Passing both fused=True and renorm=True is '
+ 'unsupported')
+ axis = [self.axis] if isinstance(self.axis, int) else self.axis
+ # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
+ # input rank is required to be 4 (which is checked later).
+ if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
+ raise ValueError('Passing fused=True is only supported when axis is 1 '
+ 'or 3')
+ if self.virtual_batch_size is not None:
+ raise ValueError('Passing fused=True is unsupported when '
+ 'virtual_batch_size is specified.')
+ if self.adjustment is not None:
+ raise ValueError('Passing fused=True is unsupported when '
+ 'adjustment is specified.')
+
+ def _fused_can_be_used(self):
+ try:
+ self._raise_if_fused_cannot_be_used()
+ return True
+ except ValueError:
+ return False
+
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if not input_shape.ndims:
@@ -191,10 +239,6 @@
if isinstance(self.axis, int):
self.axis = [self.axis]
- if not isinstance(self.axis, list):
- raise TypeError('axis must be int or list, type given: %s'
- % type(self.axis))
-
for idx, x in enumerate(self.axis):
if x < 0:
self.axis[idx] = ndims + x
@@ -219,16 +263,18 @@
raise ValueError('When using virtual_batch_size, adjustment cannot '
'be specified')
- if self.fused:
- # Currently fused batch norm doesn't support renorm. It also only supports
- # an input tensor of rank 4 and a channel dimension on axis 1 or 3.
+ if self.fused in (None, True):
# TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
# output back to its original shape accordingly.
- self.fused = (not self.renorm and
- ndims == 4 and
- self.axis in [[1], [3]] and
- self.virtual_batch_size is None and
- self.adjustment is None)
+ if self._USE_V2_BEHAVIOR:
+ if self.fused is None:
+ self.fused = (ndims == 4)
+ elif self.fused and ndims != 4:
+ raise ValueError('Batch normalization layers with fused=True only '
+ 'support 4D input tensors.')
+ else:
+ assert self.fused is not None
+ self.fused = (ndims == 4 and self._fused_can_be_used())
# TODO(chrisying): fused batch norm is currently not supported for
# multi-axis batch norm and by extension virtual batches. In some cases,
# it might be possible to use fused batch norm but would require reshaping
@@ -491,6 +537,9 @@
return (r, d, new_mean, new_variance)
+ def _moments(self, inputs, reduction_axes, keep_dims):
+ return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+
def call(self, inputs, training=None):
if training is None:
training = K.learning_phase()
@@ -562,7 +611,8 @@
# Some of the computations here are not necessary when training==False
# but not a constant. However, this makes the code simpler.
keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
- mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+ mean, variance = self._moments(
+ inputs, reduction_axes, keep_dims=keep_dims)
moving_mean = self.moving_mean
moving_variance = self.moving_variance
@@ -668,5 +718,36 @@
'layer cannot be serialized and has been omitted from '
'the layer config. It will not be included when '
're-creating the layer from the saved config.')
- base_config = super(BatchNormalization, self).get_config()
+ base_config = super(BatchNormalizationV2, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+
+
+def _replace_in_v2_docstring(old, new):
+ string = BatchNormalizationV2.__doc__
+ if old not in string:
+ raise ValueError('Could not find following string in BatchNormalizationV2 '
+ 'docstring: "{}"'.format(old))
+ return string.replace(old, new)
+
+
+@tf_export(v1=['keras.layers.BatchNormalization']) # pylint: disable=missing-docstring
+class BatchNormalizationV1(BatchNormalizationV2):
+
+ __doc__ = _replace_in_v2_docstring(
+ '''
+ fused: if `True`, use a faster, fused implementation, or raise a ValueError
+ if the fused implementation cannot be used. If `None`, use the faster
+ implementation if possible. If False, do not used the fused
+ implementation.''',
+
+ '''
+ fused: if `None` or `True`, use a faster, fused implementation if possible.
+ If `False`, use the system recommended implementation.''')
+
+ _USE_V2_BEHAVIOR = False
+
+
+if tf2.enabled():
+ BatchNormalization = BatchNormalizationV2
+else:
+ BatchNormalization = BatchNormalizationV1
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index 92e4128..2f7f042 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -23,6 +23,7 @@
from tensorflow.python import keras
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.layers import normalization
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
@@ -54,6 +55,14 @@
kwargs={'scale': False,
'center': False},
input_shape=(3, 3))
+ testing_utils.layer_test(
+ normalization.BatchNormalizationV2,
+ kwargs={'fused': True},
+ input_shape=(3, 3, 3, 3))
+ testing_utils.layer_test(
+ normalization.BatchNormalizationV2,
+ kwargs={'fused': None},
+ input_shape=(3, 3, 3))
def test_batchnorm_weights(self):
layer = keras.layers.BatchNormalization(scale=False, center=False)
@@ -78,15 +87,18 @@
self.assertEqual(layer.gamma.constraint, max_norm)
self.assertEqual(layer.beta.constraint, max_norm)
- def test_batchnorm_correctness(self):
+ def _test_batchnorm_correctness(self, dtype, use_v2=True, fused=False):
model = keras.models.Sequential()
- norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
+ layer_ctor = (normalization.BatchNormalizationV2 if use_v2
+ else normalization.BatchNormalizationV1)
+ norm = layer_ctor(input_shape=(2, 2, 2), momentum=0.8, fused=fused)
model.add(norm)
model.compile(loss='mse',
optimizer=gradient_descent.GradientDescentOptimizer(0.01))
# centered on 5.0, variance 10.0
- x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
+ x = (np.random.normal(loc=5.0, scale=10.0, size=(1000, 2, 2, 2))
+ .astype(dtype))
model.fit(x, x, epochs=4, verbose=0)
out = model.predict(x)
out -= keras.backend.eval(norm.beta)
@@ -95,23 +107,15 @@
np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
+ def test_batchnorm_correctness(self):
+ self._test_batchnorm_correctness(np.float32)
+ self._test_batchnorm_correctness(np.float32, fused=True)
+ self._test_batchnorm_correctness(np.float32, use_v2=False)
+
def test_batchnorm_mixed_precision(self):
- model = keras.models.Sequential()
- norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
- model.add(norm)
- model.compile(loss='mse',
- optimizer=gradient_descent.GradientDescentOptimizer(0.01))
-
- # centered on 5.0, variance 10.0
- x = np.random.normal(
- loc=5.0, scale=10.0, size=(1000, 10)).astype(np.float16)
- model.fit(x, x, epochs=4, verbose=0)
- out = model.predict(x)
- out -= keras.backend.eval(norm.beta)
- out /= keras.backend.eval(norm.gamma)
-
- np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
- np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
+ self._test_batchnorm_correctness(np.float16)
+ self._test_batchnorm_correctness(np.float16, fused=True)
+ self._test_batchnorm_correctness(np.float16, use_v2=False)
def test_batchnorm_convnet(self):
if test.is_gpu_available(cuda_only=True):
@@ -151,6 +155,77 @@
np.testing.assert_allclose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-1)
np.testing.assert_allclose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-1)
+ def test_v1_fused_attribute(self):
+ norm = normalization.BatchNormalizationV1()
+ inp = keras.layers.Input((4, 4, 4))
+ norm(inp)
+ self.assertEqual(norm.fused, True)
+
+ norm = normalization.BatchNormalizationV1(fused=False)
+ self.assertEqual(norm.fused, False)
+ inp = keras.layers.Input(shape=(4, 4, 4))
+ norm(inp)
+ self.assertEqual(norm.fused, False)
+
+ norm = normalization.BatchNormalizationV1(virtual_batch_size=2)
+ self.assertEqual(norm.fused, True)
+ inp = keras.layers.Input(shape=(2, 2, 2))
+ norm(inp)
+ self.assertEqual(norm.fused, False)
+
+ def test_v2_fused_attribute(self):
+ norm = normalization.BatchNormalizationV2()
+ self.assertEqual(norm.fused, None)
+ inp = keras.layers.Input(shape=(4, 4, 4))
+ norm(inp)
+ self.assertEqual(norm.fused, True)
+
+ norm = normalization.BatchNormalizationV2()
+ self.assertEqual(norm.fused, None)
+ inp = keras.layers.Input(shape=(4, 4))
+ norm(inp)
+ self.assertEqual(norm.fused, False)
+
+ norm = normalization.BatchNormalizationV2(virtual_batch_size=2)
+ self.assertEqual(norm.fused, False)
+ inp = keras.layers.Input(shape=(4, 4, 4))
+ norm(inp)
+ self.assertEqual(norm.fused, False)
+
+ norm = normalization.BatchNormalizationV2(fused=False)
+ self.assertEqual(norm.fused, False)
+ inp = keras.layers.Input(shape=(4, 4, 4))
+ norm(inp)
+ self.assertEqual(norm.fused, False)
+
+ norm = normalization.BatchNormalizationV2(fused=True, axis=[3])
+ self.assertEqual(norm.fused, True)
+ inp = keras.layers.Input(shape=(4, 4, 4))
+ norm(inp)
+ self.assertEqual(norm.fused, True)
+
+ with self.assertRaisesRegexp(ValueError, 'fused.*renorm'):
+ normalization.BatchNormalizationV2(fused=True, renorm=True)
+
+ with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'):
+ normalization.BatchNormalizationV2(fused=True, axis=2)
+
+ with self.assertRaisesRegexp(ValueError, 'fused.*when axis is 1 or 3'):
+ normalization.BatchNormalizationV2(fused=True, axis=[1, 3])
+
+ with self.assertRaisesRegexp(ValueError, 'fused.*virtual_batch_size'):
+ normalization.BatchNormalizationV2(fused=True, virtual_batch_size=2)
+
+ with self.assertRaisesRegexp(ValueError, 'fused.*adjustment'):
+ normalization.BatchNormalizationV2(fused=True,
+ adjustment=lambda _: (1, 0))
+
+ norm = normalization.BatchNormalizationV2(fused=True)
+ self.assertEqual(norm.fused, True)
+ inp = keras.layers.Input(shape=(4, 4))
+ with self.assertRaisesRegexp(ValueError, '4D input tensors'):
+ norm(inp)
+
class NormalizationLayersGraphModeOnlyTest(test.TestCase):
diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py
index 72a9c1d..a0744cd 100644
--- a/tensorflow/python/keras/layers/pooling.py
+++ b/tensorflow/python/keras/layers/pooling.py
@@ -22,8 +22,8 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
-from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
+from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index d045338..5d0efc2 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -28,8 +28,8 @@
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
-from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
+from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/python/keras/layers/unified_rnn_test.py b/tensorflow/python/keras/layers/unified_rnn_test.py
index 015a079..b08ff3c 100644
--- a/tensorflow/python/keras/layers/unified_rnn_test.py
+++ b/tensorflow/python/keras/layers/unified_rnn_test.py
@@ -19,24 +19,26 @@
from __future__ import print_function
import collections
+import time
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import keras
-from tensorflow.python.client import session
+from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import testing_utils
-from tensorflow.python.keras.engine.base_layer import \
- InputSpec
+from tensorflow.python.keras.engine.input_spec import InputSpec
+from tensorflow.python.keras.layers.cudnn_recurrent import CuDNNLSTM
from tensorflow.python.keras.layers.recurrent import RNN
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
@@ -47,20 +49,27 @@
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import gradient_descent
class RNNTest(test.TestCase):
- def test_unifiedRNN(self):
- rewrites = rewriter_config_pb2.RewriterConfig()
- rewrites.function_optimization = rewriter_config_pb2.RewriterConfig.OFF
- customer_optimizer = rewrites.custom_optimizers.add()
- customer_optimizer.name = 'ExperimentalImplementationSelector'
- rewrites.min_graph_nodes = -1
- graph_options = config_pb2.GraphOptions(rewrite_options=rewrites)
- config = config_pb2.ConfigProto(graph_options=graph_options)
+ rewrites = rewriter_config_pb2.RewriterConfig()
+ rewrites.function_optimization = rewriter_config_pb2.RewriterConfig.OFF
+ customer_optimizer = rewrites.custom_optimizers.add()
+ customer_optimizer.name = 'ExperimentalImplementationSelector'
+ rewrites.min_graph_nodes = -1
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewrites)
+ config = config_pb2.ConfigProto(graph_options=graph_options)
+ def setUp(self):
+ self.config = RNNTest.config
+
+ def tearDown(self):
+ ops.reset_default_graph()
+
+ def test_unifiedRNN(self):
input_shape = 10
rnn_state_size = 8
output_shape = 8
@@ -68,13 +77,13 @@
batch = 100
epoch = 1
- with ops.Graph().as_default(), session.Session(config=config) as sess:
+ with self.cached_session(config=self.config, use_gpu=True) as sess:
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=batch,
test_samples=0,
input_shape=(timestep, input_shape),
num_classes=output_shape)
- y_train = keras.utils.to_categorical(y_train)
+ y_train = keras.utils.to_categorical(y_train, output_shape)
layer = UnifiedLSTM(rnn_state_size)
@@ -108,14 +117,6 @@
# This test is to demonstrate the graph rewrite of grappler plugin under
# the condition that the function returns different number of internal
# states.
- rewrites = rewriter_config_pb2.RewriterConfig()
- rewrites.function_optimization = rewriter_config_pb2.RewriterConfig.OFF
- customer_optimizer = rewrites.custom_optimizers.add()
- customer_optimizer.name = 'ExperimentalImplementationSelector'
- rewrites.min_graph_nodes = -1
- graph_options = config_pb2.GraphOptions(rewrite_options=rewrites)
- config = config_pb2.ConfigProto(graph_options=graph_options)
-
input_shape = 10
rnn_state_size = 8
output_shape = 8
@@ -123,13 +124,13 @@
batch = 100
epoch = 1
- with ops.Graph().as_default(), session.Session(config=config) as sess:
+ with self.cached_session(config=self.config, use_gpu=True) as sess:
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=batch,
test_samples=0,
input_shape=(timestep, input_shape),
num_classes=output_shape)
- y_train = keras.utils.to_categorical(y_train)
+ y_train = keras.utils.to_categorical(y_train, output_shape)
layer = UnifiedLSTM(rnn_state_size)
@@ -169,11 +170,166 @@
self.assertNotEqual(existing_loss, loss_value)
existing_loss = loss_value
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def test_keras_model_with_lstm(self):
+ input_shape = 10
+ rnn_state_size = 8
+ output_shape = 8
+ timestep = 4
+ batch = 100
+ epoch = 10
+
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = keras.utils.to_categorical(y_train, output_shape)
+
+ layer = UnifiedLSTM(rnn_state_size)
+
+ inputs = keras.layers.Input(
+ shape=[timestep, input_shape], dtype=dtypes.float32)
+
+ outputs, unused_runtime = layer(inputs)
+ model = keras.models.Model(inputs, outputs)
+ model.compile('rmsprop', loss='mse')
+ model.fit(x_train, y_train, epochs=epoch)
+
+ def _measure_performance(self, test_config, model, x_train, y_train):
+ batch = test_config['batch']
+ epoch = test_config['epoch']
+ warmup_epoch = test_config['warmup_epoch']
+
+ # warm up the model
+ model.fit(x_train, y_train, batch_size=batch, epochs=warmup_epoch)
+ start_time = time.time()
+ model.fit(x_train, y_train, batch_size=batch, epochs=epoch - warmup_epoch)
+ end_time = time.time()
+ return (end_time - start_time) / (epoch - warmup_epoch)
+
+ def _time_performance_run_cudnn_lstm(self, test_config, x_train, y_train):
+ # Get the performance number for standard Cudnn LSTM
+ input_shape = test_config['input_shape']
+ rnn_state_size = test_config['rnn_state_size']
+ timestep = test_config['timestep']
+
+ cudnn_lstm_layer = CuDNNLSTM(rnn_state_size)
+ inputs = keras.layers.Input(
+ shape=[timestep, input_shape], dtype=dtypes.float32)
+
+ outputs = cudnn_lstm_layer(inputs)
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', 'mse')
+
+ sec_per_epoch = self._measure_performance(
+ test_config, model, x_train, y_train)
+ logging.info('Average performance for %s per epoch is: %s',
+ 'CuDNN LSTM', sec_per_epoch)
+ return sec_per_epoch
+
+ def _time_performance_run_unifed_lstm_gpu(
+ self, test_config, x_train, y_train):
+ # Get performance number for Unified_LSTM with grappler swap the impl
+ input_shape = test_config['input_shape']
+ rnn_state_size = test_config['rnn_state_size']
+ timestep = test_config['timestep']
+
+ layer = UnifiedLSTM(rnn_state_size)
+ inputs = keras.layers.Input(
+ shape=[timestep, input_shape], dtype=dtypes.float32)
+
+ outputs, _ = layer(inputs)
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', 'mse')
+
+ sec_per_epoch = self._measure_performance(
+ test_config, model, x_train, y_train)
+ logging.info('Average performance for %s per epoch is: %s',
+ 'Unified LSTM', sec_per_epoch)
+ return sec_per_epoch
+
+ def _time_performance_run_normal_lstm(
+ self, test_config, x_train, y_train):
+ # Get performance number for standard LSTM on GPU.
+ input_shape = test_config['input_shape']
+ rnn_state_size = test_config['rnn_state_size']
+ timestep = test_config['timestep']
+
+ layer = keras.layers.LSTM(rnn_state_size)
+ inputs = keras.layers.Input(
+ shape=[timestep, input_shape], dtype=dtypes.float32)
+
+ outputs = layer(inputs)
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', 'mse')
+
+ sec_per_epoch = self._measure_performance(
+ test_config, model, x_train, y_train)
+ logging.info('Average performance for %s per epoch is: %s',
+ 'Normal LSTM', sec_per_epoch)
+ return sec_per_epoch
+
+ @test_util.run_in_graph_and_eager_modes(config=config, use_gpu=True)
+ def test_performance_with_standard_cudnn_impl(self):
+ if not test.is_gpu_available():
+ self.skipTest('performance test will only run on GPU')
+
+ batch = 64
+ num_batch = 10
+ test_config = {
+ 'input_shape': 128,
+ 'rnn_state_size': 64,
+ 'output_shape': 64,
+ 'timestep': 50,
+ 'batch': batch,
+ 'epoch': 20,
+ # The performance for warmup epoch is ignored.
+ 'warmup_epoch': 1,
+ }
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=(batch * num_batch),
+ test_samples=0,
+ input_shape=(test_config['timestep'], test_config['input_shape']),
+ num_classes=test_config['output_shape'])
+ y_train = keras.utils.to_categorical(y_train, test_config['output_shape'])
+
+ cudnn_duration = self._time_performance_run_cudnn_lstm(
+ test_config, x_train, y_train)
+ unified_lstm_gpu_duration = self._time_performance_run_unifed_lstm_gpu(
+ test_config, x_train, y_train)
+ normal_lstm_duration = self._time_performance_run_normal_lstm(
+ test_config, x_train, y_train)
+
+ cudnn_vs_unified = cudnn_duration / unified_lstm_gpu_duration
+ unified_vs_normal = normal_lstm_duration / unified_lstm_gpu_duration
+
+ # TODO(scottzhu): reeanble the test after moving it to benchmark test suite.
+ # The current test has performance flakiness issue.
+ logging.info('Expect the performance of Unified LSTM is within 80% of '
+ 'CuDNN LSTM, got {0:.2f}%'.format(cudnn_vs_unified * 100))
+ logging.info('Expect the performance of Unified LSTM is more than 5 times'
+ ' of normal LSTM, got {0:.2f}'.format(unified_vs_normal))
+
+ # Assert the performance diff should be within 80% of the native cudnn.
+ # self.assertGreaterEqual(
+ # cudnn_vs_unified, 0.80,
+ # 'Expect the performance of Unified LSTM is within 80% of CuDNN LSTM, '
+ # 'but got {0:.2f}%'.format(cudnn_vs_unified * 100))
+ # # Assert the performance diff between CPU impl and GPU impl should be more
+ # # than 5 times.
+ # self.assertGreaterEqual(
+ # unified_vs_normal, 5,
+ # 'Expect the performance of Unified LSTM is more than 5 times of '
+ # 'normal LSTM, but got {0:.2f}'.format(unified_vs_normal))
+
class UnifiedLSTM(RNN):
def __init__(self,
units,
+ activation='tanh',
+ recurrent_activation='hard_sigmoid',
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
@@ -196,7 +352,8 @@
cell_spec = collections.namedtuple('cell', ['state_size', 'output_size'])
self.cell = cell_spec(
state_size=(self.units, self.units), output_size=self.units)
-
+ self.activation = activations.get(activation)
+ self.recurrent_activation = activations.get(recurrent_activation)
self.kernel_initializer = initializers.get(kernel_initializer)
self.recurrent_initializer = initializers.get(recurrent_initializer)
self.bias_initializer = initializers.get(bias_initializer)
@@ -290,12 +447,25 @@
# Reverse time axis.
inputs = K.reverse(inputs, 1)
- outputs, [new_h, new_c], runtime = normal_lstm(
- inputs, initial_state[0], initial_state[1], self.kernel,
- self.recurrent_kernel, self.bias, self.units)
+ if ops.executing_eagerly_outside_functions():
+ if context.num_gpus() > 0:
+ outputs, [new_h, new_c], runtime = cudnn_lstm(
+ inputs, initial_state[0], initial_state[1], self.kernel,
+ self.recurrent_kernel, self.bias, self.units)
+ else:
+ outputs, [new_h, new_c], runtime = normal_lstm(
+ inputs, initial_state[0], initial_state[1], self.kernel,
+ self.recurrent_kernel, self.bias, self.units, self.activation,
+ self.recurrent_activation)
+ else:
+ outputs, [new_h, new_c], runtime = normal_lstm(
+ inputs, initial_state[0], initial_state[1], self.kernel,
+ self.recurrent_kernel, self.bias, self.units, self.activation,
+ self.recurrent_activation)
- function.register(cudnn_lstm, inputs, initial_state[0], initial_state[1],
- self.kernel, self.recurrent_kernel, self.bias, self.units)
+ function.register(cudnn_lstm, inputs, initial_state[0], initial_state[1],
+ self.kernel, self.recurrent_kernel, self.bias,
+ self.units)
states = [new_h, new_c]
@@ -385,7 +555,8 @@
'experimental_api_implements': 'lstm',
'experimental_api_preferred_device': 'CPU'
})
-def normal_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, units):
+def normal_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, units,
+ activation, recurrent_activation):
input_shape = K.int_shape(inputs)
timesteps = input_shape[1]
@@ -405,12 +576,12 @@
z2 = z[:, 2 * units:3 * units]
z3 = z[:, 3 * units:]
- i = activations.get('hard_sigmoid')(z0)
- f = activations.get('hard_sigmoid')(z1)
- c = f * c_tm1 + i * activations.get('tanh')(z2)
- o = activations.get('hard_sigmoid')(z3)
+ i = recurrent_activation(z0)
+ f = recurrent_activation(z1)
+ c = f * c_tm1 + i * activation(z2)
+ o = recurrent_activation(z3)
- h = o * activations.get('tanh')(c)
+ h = o * activation(c)
return h, [h, c]
_, outputs, new_states = K.rnn(
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index 27419a1..67b1541 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -23,8 +23,8 @@
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend as K
-from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
+from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.layers.recurrent import _standardize_args
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py
index 9f548bf..1bd9f72 100644
--- a/tensorflow/python/keras/losses.py
+++ b/tensorflow/python/keras/losses.py
@@ -19,16 +19,251 @@
from __future__ import division
from __future__ import print_function
+import abc
+
import six
+from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.keras.utils.losses_utils import compute_weighted_loss
+from tensorflow.python.keras.utils.losses_utils import ReductionV2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.util.tf_export import tf_export
+class Loss(object):
+ """Loss base class.
+
+ To be implemented by subclasses:
+ * `call()`: Contains the logic for loss calculation using `y_true`, `y_pred`.
+
+ Example subclass implementation:
+ ```
+ class MeanSquaredError(Loss):
+ def call(self, y_true, y_pred):
+ y_pred = ops.convert_to_tensor(y_pred)
+ y_true = math_ops.cast(y_true, y_pred.dtype)
+ return K.mean(math_ops.square(y_pred - y_true), axis=-1)
+ ```
+
+ Args:
+ reduction: Type of `tf.losses.Reduction` to apply to loss. Default value is
+ `SUM_OVER_BATCH_SIZE`.
+ name: Optional name for the op.
+ """
+
+ def __init__(self, reduction=ReductionV2.SUM_OVER_BATCH_SIZE, name=None):
+ self.reduction = reduction
+ self.name = name
+
+ def __call__(self, y_true, y_pred, sample_weight=None):
+ """Invokes the `Loss` instance.
+
+ Args:
+ y_true: Ground truth values.
+ y_pred: The predicted values.
+ sample_weight: Optional `Tensor` whose rank is either 0, or the same rank
+ as `y_true`, or is broadcastable to `y_true`. `sample_weight` acts as a
+ coefficient for the loss. If a scalar is provided, then the loss is
+ simply scaled by the given value. If `sample_weight` is a tensor of size
+ `[batch_size]`, then the total loss for each sample of the batch is
+ rescaled by the corresponding element in the `sample_weight` vector. If
+ the shape of `sample_weight` matches the shape of `y_pred`, then the
+ loss of each measurable element of `y_pred` is scaled by the
+ corresponding value of `sample_weight`.
+
+ Returns:
+ Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
+ shape as `y_true`; otherwise, it is scalar.
+
+ Raises:
+ ValueError: If the shape of `sample_weight` is invalid.
+ """
+ with ops.name_scope(self.name, format(self.__class__.__name__),
+ (y_pred, y_true, sample_weight)):
+ losses = self.call(y_true, y_pred)
+ return compute_weighted_loss(
+ losses, sample_weight, reduction=self.reduction)
+
+ @classmethod
+ def from_config(cls, config):
+ """Instantiates a `Loss` from its config (output of `get_config()`).
+
+ Args:
+ config: Output of `get_config()`.
+
+ Returns:
+ A `Loss` instance.
+ """
+ return cls(**config)
+
+ def get_config(self):
+ return {'reduction': self.reduction, 'name': self.name}
+
+ @abc.abstractmethod
+ def call(self, y_true, y_pred):
+ """Invokes the `Loss` instance.
+
+ Args:
+ y_true: Ground truth values, with the same shape as 'y_pred'.
+ y_pred: The predicted values.
+ """
+ NotImplementedError('Must be implemented in subclasses.')
+
+
+@tf_export('losses.MeanSquaredError', 'keras.losses.MeanSquaredError')
+class MeanSquaredError(Loss):
+ """Computes the mean of squares of errors between labels and predictions.
+
+ For example, if `y_true` is [0., 0., 1., 1.] and `y_pred` is [1., 1., 1., 0.]
+ then the mean squared error value is 3/4 (0.75).
+
+ Usage:
+
+ ```python
+ mse = tf.losses.MeanSquaredError()
+ loss = mse([0., 0., 1., 1.], [1., 1., 1., 0.])
+ print('Loss: ', loss.numpy()) # Loss: 0.75
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', loss=tf.losses.MeanSquaredError())
+ ```
+ """
+
+ def call(self, y_true, y_pred):
+ """Invokes the `MeanSquaredError` instance.
+
+ Args:
+ y_true: Ground truth values.
+ y_pred: The predicted values.
+
+ Returns:
+ Mean squared error losses.
+ """
+ y_pred = ops.convert_to_tensor(y_pred)
+ y_true = math_ops.cast(y_true, y_pred.dtype)
+ return mean_squared_error(y_true, y_pred)
+
+
+class MeanAbsoluteError(Loss):
+ """Computes the mean of absolute difference between labels and predictions.
+
+ For example, if `y_true` is [0., 0., 1., 1.] and `y_pred` is [1., 1., 1., 0.]
+ then the mean absolute error value is 3/4 (0.75).
+
+ Usage:
+
+ ```python
+ mae = tf.losses.MeanAbsoluteError()
+ loss = mae([0., 0., 1., 1.], [1., 1., 1., 0.])
+ print('Loss: ', loss.numpy()) # Loss: 0.75
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', loss=tf.losses.MeanAbsoluteError())
+ ```
+ """
+
+ def call(self, y_true, y_pred):
+ """Invokes the `MeanAbsoluteError` instance.
+
+ Args:
+ y_true: Ground truth values.
+ y_pred: The predicted values.
+
+ Returns:
+ Mean absolute error losses.
+ """
+ y_pred = ops.convert_to_tensor(y_pred)
+ y_true = math_ops.cast(y_true, y_pred.dtype)
+ return mean_absolute_error(y_true, y_pred)
+
+
+class MeanAbsolutePercentageError(Loss):
+ """Computes the mean absolute percentage error between `y_true` and `y_pred`.
+
+ For example, if `y_true` is [0., 0., 1., 1.] and `y_pred` is [1., 1., 1., 0.]
+ then the mean absolute percentage error value is 5e+08.
+
+ Usage:
+
+ ```python
+ mape = tf.losses.MeanAbsolutePercentageError()
+ loss = mape([0., 0., 1., 1.], [1., 1., 1., 0.])
+ print('Loss: ', loss.numpy()) # Loss: 5e+08
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', loss=tf.losses.MeanAbsolutePercentageError())
+ ```
+ """
+
+ def call(self, y_true, y_pred):
+ """Invokes the `MeanAbsolutePercentageError` instance.
+
+ Args:
+ y_true: Ground truth values.
+ y_pred: The predicted values.
+
+ Returns:
+ Mean absolute percentage error losses.
+ """
+ y_pred = ops.convert_to_tensor(y_pred)
+ y_true = math_ops.cast(y_true, y_pred.dtype)
+ return mean_absolute_percentage_error(y_true, y_pred)
+
+
+class MeanSquaredLogarithmicError(Loss):
+ """Computes the mean squared logarithmic error between `y_true` and `y_pred`.
+
+ For example, if `y_true` is [0., 0., 1., 1.] and `y_pred` is [1., 1., 1., 0.]
+ then the mean squared logarithmic error value is 0.36034.
+
+ Usage:
+
+ ```python
+ msle = tf.losses.MeanSquaredLogarithmicError()
+ loss = msle([0., 0., 1., 1.], [1., 1., 1., 0.])
+ print('Loss: ', loss.numpy()) # Loss: 0.36034
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', loss=tf.losses.MeanSquaredLogarithmicError())
+ ```
+ """
+
+ def call(self, y_true, y_pred):
+ """Invokes the `MeanSquaredLogarithmicError` instance.
+
+ Args:
+ y_true: Ground truth values.
+ y_pred: The predicted values.
+
+ Returns:
+ Mean squared logarithmic error losses.
+ """
+ y_pred = ops.convert_to_tensor(y_pred)
+ y_true = math_ops.cast(y_true, y_pred.dtype)
+ return mean_squared_logarithmic_error(y_true, y_pred)
+
+
@tf_export('keras.metrics.mean_squared_error',
'keras.metrics.mse',
'keras.metrics.MSE',
@@ -197,3 +432,9 @@
else:
raise ValueError('Could not interpret '
'loss function identifier:', identifier)
+
+
+LABEL_DTYPES_FOR_LOSSES = {
+ losses_impl.sparse_softmax_cross_entropy: 'int32',
+ sparse_categorical_crossentropy: 'int32'
+}
diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py
index c701527..d80b272 100644
--- a/tensorflow/python/keras/losses_test.py
+++ b/tensorflow/python/keras/losses_test.py
@@ -24,6 +24,9 @@
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
try:
@@ -138,5 +141,305 @@
loaded_model.predict(np.random.rand(128, 2))
+@test_util.run_all_in_graph_and_eager_modes
+class MeanSquaredErrorTest(test.TestCase):
+
+ def test_config(self):
+ mse_obj = keras.losses.MeanSquaredError(
+ reduction=keras.losses.ReductionV2.SUM, name='mse_1')
+ self.assertEqual(mse_obj.name, 'mse_1')
+ self.assertEqual(mse_obj.reduction, keras.losses.ReductionV2.SUM)
+
+ def test_all_correct_unweighted(self):
+ mse_obj = keras.losses.MeanSquaredError()
+ y_true = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
+ loss = mse_obj(y_true, y_true)
+ self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
+
+ def test_unweighted(self):
+ mse_obj = keras.losses.MeanSquaredError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = mse_obj(y_true, y_pred)
+ self.assertAlmostEqual(self.evaluate(loss), 49.5, 3)
+
+ def test_scalar_weighted(self):
+ mse_obj = keras.losses.MeanSquaredError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = mse_obj(y_true, y_pred, sample_weight=2.3)
+ self.assertAlmostEqual(self.evaluate(loss), 113.85, 3)
+
+ def test_sample_weighted(self):
+ mse_obj = keras.losses.MeanSquaredError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
+ loss = mse_obj(y_true, y_pred, sample_weight=sample_weight)
+ self.assertAlmostEqual(self.evaluate(loss), 767.8 / 6, 3)
+
+ def test_timestep_weighted(self):
+ mse_obj = keras.losses.MeanSquaredError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3, 1),
+ dtype=dtypes.float32)
+ sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3))
+ loss = mse_obj(y_true, y_pred, sample_weight=sample_weight)
+ self.assertAlmostEqual(self.evaluate(loss), 587 / 6, 3)
+
+ def test_zero_weighted(self):
+ mse_obj = keras.losses.MeanSquaredError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = mse_obj(y_true, y_pred, sample_weight=0)
+ self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
+
+ def test_invalid_sample_weight(self):
+ mse_obj = keras.losses.MeanSquaredError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3, 1))
+ sample_weight = constant_op.constant([3, 6, 5, 0], shape=(2, 2))
+ with self.assertRaisesRegexp(
+ ValueError, r'Shapes \(2, 2\) and \(2, 3\) are incompatible'):
+ mse_obj(y_true, y_pred, sample_weight=sample_weight)
+
+ def test_no_reduction(self):
+ mse_obj = keras.losses.MeanSquaredError(
+ reduction=keras.losses.ReductionV2.NONE)
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = mse_obj(y_true, y_pred, sample_weight=2.3)
+ loss = self.evaluate(loss)
+ self.assertArrayNear(loss, [84.3333, 143.3666], 1e-3)
+
+ def test_sum_reduction(self):
+ mse_obj = keras.losses.MeanSquaredError(
+ reduction=keras.losses.ReductionV2.SUM)
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = mse_obj(y_true, y_pred, sample_weight=2.3)
+ self.assertAlmostEqual(self.evaluate(loss), 227.69998, 3)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class MeanAbsoluteErrorTest(test.TestCase):
+
+ def test_config(self):
+ mae_obj = keras.losses.MeanAbsoluteError(
+ reduction=keras.losses.ReductionV2.SUM, name='mae_1')
+ self.assertEqual(mae_obj.name, 'mae_1')
+ self.assertEqual(mae_obj.reduction, keras.losses.ReductionV2.SUM)
+
+ def test_all_correct_unweighted(self):
+ mae_obj = keras.losses.MeanAbsoluteError()
+ y_true = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
+ loss = mae_obj(y_true, y_true)
+ self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
+
+ def test_unweighted(self):
+ mae_obj = keras.losses.MeanAbsoluteError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = mae_obj(y_true, y_pred)
+ self.assertAlmostEqual(self.evaluate(loss), 5.5, 3)
+
+ def test_scalar_weighted(self):
+ mae_obj = keras.losses.MeanAbsoluteError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = mae_obj(y_true, y_pred, sample_weight=2.3)
+ self.assertAlmostEqual(self.evaluate(loss), 12.65, 3)
+
+ def test_sample_weighted(self):
+ mae_obj = keras.losses.MeanAbsoluteError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
+ loss = mae_obj(y_true, y_pred, sample_weight=sample_weight)
+ self.assertAlmostEqual(self.evaluate(loss), 81.4 / 6, 3)
+
+ def test_timestep_weighted(self):
+ mae_obj = keras.losses.MeanAbsoluteError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3, 1),
+ dtype=dtypes.float32)
+ sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3))
+ loss = mae_obj(y_true, y_pred, sample_weight=sample_weight)
+ self.assertAlmostEqual(self.evaluate(loss), 83 / 6, 3)
+
+ def test_zero_weighted(self):
+ mae_obj = keras.losses.MeanAbsoluteError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = mae_obj(y_true, y_pred, sample_weight=0)
+ self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
+
+ def test_invalid_sample_weight(self):
+ mae_obj = keras.losses.MeanAbsoluteError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3], shape=(2, 3, 1))
+ sample_weight = constant_op.constant([3, 6, 5, 0], shape=(2, 2))
+ with self.assertRaisesRegexp(
+ ValueError, r'Shapes \(2, 2\) and \(2, 3\) are incompatible'):
+ mae_obj(y_true, y_pred, sample_weight=sample_weight)
+
+ def test_no_reduction(self):
+ mae_obj = keras.losses.MeanAbsoluteError(
+ reduction=keras.losses.ReductionV2.NONE)
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = mae_obj(y_true, y_pred, sample_weight=2.3)
+ loss = self.evaluate(loss)
+ self.assertArrayNear(loss, [10.7333, 14.5666], 1e-3)
+
+ def test_sum_reduction(self):
+ mae_obj = keras.losses.MeanAbsoluteError(
+ reduction=keras.losses.ReductionV2.SUM)
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = mae_obj(y_true, y_pred, sample_weight=2.3)
+ self.assertAlmostEqual(self.evaluate(loss), 25.29999, 3)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class MeanAbsolutePercentageErrorTest(test.TestCase):
+
+ def test_config(self):
+ mape_obj = keras.losses.MeanAbsolutePercentageError(
+ reduction=keras.losses.ReductionV2.SUM, name='mape_1')
+ self.assertEqual(mape_obj.name, 'mape_1')
+ self.assertEqual(mape_obj.reduction, keras.losses.ReductionV2.SUM)
+
+ def test_unweighted(self):
+ mape_obj = keras.losses.MeanAbsolutePercentageError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = mape_obj(y_true, y_pred)
+ self.assertAlmostEqual(self.evaluate(loss), 211.8518, 3)
+
+ def test_scalar_weighted(self):
+ mape_obj = keras.losses.MeanAbsolutePercentageError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = mape_obj(y_true, y_pred, sample_weight=2.3)
+ self.assertAlmostEqual(self.evaluate(loss), 487.259, 3)
+
+ def test_sample_weighted(self):
+ mape_obj = keras.losses.MeanAbsolutePercentageError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
+ loss = mape_obj(y_true, y_pred, sample_weight=sample_weight)
+ self.assertAlmostEqual(self.evaluate(loss), 422.8888, 3)
+
+ def test_timestep_weighted(self):
+ mape_obj = keras.losses.MeanAbsolutePercentageError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3, 1),
+ dtype=dtypes.float32)
+ sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3))
+ loss = mape_obj(y_true, y_pred, sample_weight=sample_weight)
+ self.assertAlmostEqual(self.evaluate(loss), 694.4445, 3)
+
+ def test_zero_weighted(self):
+ mape_obj = keras.losses.MeanAbsolutePercentageError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = mape_obj(y_true, y_pred, sample_weight=0)
+ self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class MeanSquaredLogarithmicErrorTest(test.TestCase):
+
+ def test_config(self):
+ msle_obj = keras.losses.MeanSquaredLogarithmicError(
+ reduction=keras.losses.ReductionV2.SUM, name='mape_1')
+ self.assertEqual(msle_obj.name, 'mape_1')
+ self.assertEqual(msle_obj.reduction, keras.losses.ReductionV2.SUM)
+
+ def test_unweighted(self):
+ msle_obj = keras.losses.MeanSquaredLogarithmicError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = msle_obj(y_true, y_pred)
+ self.assertAlmostEqual(self.evaluate(loss), 1.4370, 3)
+
+ def test_scalar_weighted(self):
+ msle_obj = keras.losses.MeanSquaredLogarithmicError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = msle_obj(y_true, y_pred, sample_weight=2.3)
+ self.assertAlmostEqual(self.evaluate(loss), 3.3051, 3)
+
+ def test_sample_weighted(self):
+ msle_obj = keras.losses.MeanSquaredLogarithmicError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
+ loss = msle_obj(y_true, y_pred, sample_weight=sample_weight)
+ self.assertAlmostEqual(self.evaluate(loss), 3.7856, 3)
+
+ def test_timestep_weighted(self):
+ msle_obj = keras.losses.MeanSquaredLogarithmicError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3, 1),
+ dtype=dtypes.float32)
+ sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3))
+ loss = msle_obj(y_true, y_pred, sample_weight=sample_weight)
+ self.assertAlmostEqual(self.evaluate(loss), 2.6473, 3)
+
+ def test_zero_weighted(self):
+ msle_obj = keras.losses.MeanSquaredLogarithmicError()
+ y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+ y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
+ shape=(2, 3),
+ dtype=dtypes.float32)
+ loss = msle_obj(y_true, y_pred, sample_weight=0)
+ self.assertAlmostEqual(self.evaluate(loss), 0.0, 3)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 7848be3..b74b6cc 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -48,9 +48,10 @@
from tensorflow.python.keras.losses import squared_hinge
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
+from tensorflow.python.keras.utils.generic_utils import to_list
+from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -64,13 +65,6 @@
from tensorflow.tools.docs import doc_controls
-def check_is_tensor_or_operation(x, name):
- """Raises type error if the given input is not a tensor or operation."""
- if not (isinstance(x, ops.Tensor) or isinstance(x, ops.Operation)):
- raise TypeError('{0} must be a Tensor or Operation, given: {1}'.format(
- name, x))
-
-
def clone_metric(metric):
"""Returns a clone of the metric if stateful, otherwise returns it as is."""
if isinstance(metric, Metric):
@@ -103,8 +97,6 @@
update_op = update_state_fn(*args, **kwargs)
if update_op is not None: # update_op will be None in eager execution.
metric_obj.add_update(update_op, inputs=True)
- check_is_tensor_or_operation(
- update_op, 'Metric {0}\'s update'.format(metric_obj.name))
return update_op
return tf_decorator.make_decorator(update_state_fn, decorated)
@@ -129,7 +121,7 @@
`merge_call()`.
"""
- def decorated(metric_obj, *args):
+ def decorated(_, *args):
"""Decorated function with merge_call."""
replica_context = distribution_strategy_context.get_replica_context()
if replica_context is None: # if in cross replica context already
@@ -150,8 +142,6 @@
# replica mode and compute a value in cross replica mode.
result_t = replica_context.merge_call(
merge_fn_wrapper, args=(result_fn,) + args)
- check_is_tensor_or_operation(result_t,
- 'Metric {0}\'s result'.format(metric_obj.name))
return result_t
return tf_decorator.make_decorator(result_fn, decorated)
@@ -172,77 +162,6 @@
return inner
-def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
- """Squeeze or expand last dimension if needed.
-
- 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
- (using `confusion_matrix.remove_squeezable_dimensions`).
- 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
- from the new rank of `y_pred`.
- If `sample_weight` is scalar, it is kept scalar.
-
- This will use static shape if available. Otherwise, it will add graph
- operations, which could result in a performance hit.
-
- Args:
- y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
- y_true: Optional label `Tensor` whose dimensions match `y_pred`.
- sample_weight: Optional weight scalar or `Tensor` whose dimensions match
- `y_pred`.
-
- Returns:
- Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
- the last dimension squeezed,
- `sample_weight` could be extended by one dimension.
- """
- if y_true is not None:
- # squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
- y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
- y_true, y_pred)
-
- if sample_weight is None:
- return y_pred, y_true, None
-
- sample_weight = ops.convert_to_tensor(sample_weight)
- weights_shape = sample_weight.get_shape()
- weights_rank = weights_shape.ndims
- if weights_rank == 0: # If weights is scalar, do nothing.
- return y_pred, y_true, sample_weight
-
- y_pred_shape = y_pred.get_shape()
- y_pred_rank = y_pred_shape.ndims
- if (y_pred_rank is not None) and (weights_rank is not None):
- # Use static rank.
- if weights_rank - y_pred_rank == 1:
- sample_weight = array_ops.squeeze(sample_weight, [-1])
- elif y_pred_rank - weights_rank == 1:
- sample_weight = array_ops.expand_dims(sample_weight, [-1])
- return y_pred, y_true, sample_weight
-
- # Use dynamic rank.
- weights_rank_tensor = array_ops.rank(sample_weight)
- rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
- maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])
-
- def _maybe_expand_weights():
- return control_flow_ops.cond(
- math_ops.equal(rank_diff,
- -1), lambda: array_ops.expand_dims(sample_weight, [-1]),
- lambda: sample_weight)
-
- def _maybe_adjust_weights():
- return control_flow_ops.cond(
- math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
- _maybe_expand_weights)
-
- # squeeze or expand last dim of `sample_weight` if its rank differs by 1
- # from the new rank of `y_pred`.
- sample_weight = control_flow_ops.cond(
- math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
- _maybe_adjust_weights)
- return y_pred, y_true, sample_weight
-
-
class _ConfusionMatrix(Enum):
TRUE_POSITIVES = 'tp'
FALSE_POSITIVES = 'fp'
@@ -286,7 +205,8 @@
y_true: A `Tensor` whose shape matches `y_pred`. Will be cast to `bool`.
y_pred: A floating point `Tensor` of arbitrary shape and whose values are in
the range `[0, 1]`.
- thresholds: A python list or tuple of float thresholds in `[0, 1]`.
+ thresholds: A float value or a python list or tuple of float thresholds in
+ `[0, 1]`.
sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
`y_true`, and must be broadcastable to `y_true` (i.e., all dimensions must
be either `1`, or the same as the corresponding `y_true` dimension).
@@ -301,7 +221,9 @@
"""
if variables_to_update is None:
return
- y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
+ y_true = ops.convert_to_tensor(y_true)
+ y_pred = ops.convert_to_tensor(y_pred)
+ y_pred.shape.assert_is_compatible_with(y_true.shape)
if not any(
key for key in variables_to_update if key in list(_ConfusionMatrix)):
@@ -333,6 +255,7 @@
math_ops.cast(y_pred, dtype=dtypes.float32),
math_ops.cast(y_true, dtype=dtypes.bool), sample_weight)
+ thresholds = to_list(thresholds)
num_thresholds = len(thresholds)
num_predictions = array_ops.size(y_pred)
@@ -401,7 +324,7 @@
class Metric(Layer):
"""Encapsulates metric logic and state.
- Usage with eager execution:
+ Usage:
```python
m = SomeMetric(...)
@@ -410,19 +333,6 @@
print('Final result: ', m.result().numpy())
```
- Usage with graph execution:
-
- ```python
- m = SomeMetric(...)
- init_op = tf.variables_initializer(m.variables) # Initialize variables
- with tf.Session() as sess:
- sess.run(init_op)
- for input in ...:
- update_op = m.update_state(input)
- sess.run(update_op)
- print('Final result: ', sess.run(m.result()))
- ```
-
Usage with tf.keras API:
```python
@@ -600,15 +510,35 @@
### End: For use by subclasses ###
+@tf_export('metrics.Mean', 'keras.metrics.Mean')
class Mean(Metric):
"""Computes the (weighted) mean of the given values.
+ For example, if values is [1, 3, 5, 7] then the mean is 4.
+ If the weights were specified as [1, 1, 0, 0] then the mean would be 2.
+
This metric creates two variables, `total` and `count` that are used to
compute the average of `values`. This average is ultimately returned as `mean`
which is an idempotent operation that simply divides `total` by `count`.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
+
+ Usage:
+
+ ```python
+ m = tf.metrics.Mean()
+ m.update_state([1, 3, 5, 7])
+ print('Final result: ', m.result().numpy()) # Final result: 4.0
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.add_metric(metrics_module.Mean(name='mean_1')(outputs))
+ model.compile('sgd', loss='mse')
+ ```
"""
def __init__(self, name='mean', dtype=None):
@@ -666,8 +596,7 @@
# updated.
update_total_op = state_ops.assign_add(self.total, values)
with ops.control_dependencies([update_total_op]):
- update_count_op = state_ops.assign_add(self.count, num_values)
- return ops.convert_to_tensor(update_count_op)
+ return state_ops.assign_add(self.count, num_values)
def result(self):
return math_ops.div_no_nan(self.total, self.count)
@@ -721,9 +650,14 @@
return dict(list(base_config.items()) + list(config.items()))
-class BinaryAccuracy(MeanMetricWrapper):
+@tf_export('metrics.Accuracy', 'keras.metrics.Accuracy')
+class Accuracy(MeanMetricWrapper):
"""Calculates how often predictions matches labels.
+ For example, if `y_true` is [1, 2, 3, 4] and `y_pred` is [0, 2, 3, 4]
+ then the accuracy is 3/4 or .75. If the weights were specified as
+ [1, 1, 0, 0] then the accuracy would be 1/2 or .5.
+
This metric creates two local variables, `total` and `count` that are used to
compute the frequency with which `y_pred` matches `y_true`. This frequency is
ultimately returned as `binary accuracy`: an idempotent operation that simply
@@ -731,6 +665,63 @@
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
+
+ Usage:
+
+ ```python
+ m = tf.metrics.Accuracy()
+ m.update_state([1, 2, 3, 4], [0, 2, 3, 4])
+ print('Final result: ', m.result().numpy()) # Final result: 0.75
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', loss='mse', metrics=[tf.metrics.Accuracy()])
+ ```
+ """
+
+ def __init__(self, name='accuracy', dtype=None):
+ super(Accuracy, self).__init__(accuracy, name, dtype=dtype)
+
+ @classmethod
+ def from_config(cls, config):
+ if 'fn' in config:
+ config.pop('fn')
+ return super(Accuracy, cls).from_config(config)
+
+
+@tf_export('metrics.BinaryAccuracy', 'keras.metrics.BinaryAccuracy')
+class BinaryAccuracy(MeanMetricWrapper):
+ """Calculates how often predictions matches labels.
+
+ For example, if `y_true` is [1, 1, 0, 0] and `y_pred` is [0.98, 1, 0, 0.6]
+ then the binary accuracy is 3/4 or .75. If the weights were specified as
+ [1, 0, 0, 1] then the binary accuracy would be 1/2 or .5.
+
+ This metric creates two local variables, `total` and `count` that are used to
+ compute the frequency with which `y_pred` matches `y_true`. This frequency is
+ ultimately returned as `binary accuracy`: an idempotent operation that simply
+ divides `total` by `count`.
+
+ If `sample_weight` is `None`, weights default to 1.
+ Use `sample_weight` of 0 to mask values.
+
+ Usage:
+
+ ```python
+ m = tf.metrics.BinaryAccuracy()
+ m.update_state([1, 1, 0, 0], [0.98, 1, 0, 0.6])
+ print('Final result: ', m.result().numpy()) # Final result: 0.75
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', loss='mse', metrics=[tf.metrics.BinaryAccuracy()])
+ ```
"""
def __init__(self, name='binary_accuracy', dtype=None, threshold=0.5):
@@ -752,16 +743,41 @@
return super(BinaryAccuracy, cls).from_config(config)
+@tf_export(
+ 'metrics.CategoricalAccuracy', 'keras.metrics.CategoricalAccuracy')
class CategoricalAccuracy(MeanMetricWrapper):
"""Calculates how often predictions matches labels.
+ For example, if `y_true` is [[0, 0, 1], [0, 1, 0]] and `y_pred` is
+ [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5.
+ If the weights were specified as [0.7, 0.3] then the categorical accuracy
+ would be .3.
+
This metric creates two local variables, `total` and `count` that are used to
compute the frequency with which `y_pred` matches `y_true`. This frequency is
ultimately returned as `categorical accuracy`: an idempotent operation that
simply divides `total` by `count`.
+ `y_pred` and `y_true` should be passed in as vectors of probabilities, rather
+ than as labels. If necessary, use `tf.one_hot` to expand `y_true` as a vector.
+
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
+
+ Usage:
+
+ ```python
+ m = tf.metrics.CategoricalAccuracy()
+ m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
+ print('Final result: ', m.result().numpy()) # Final result: 0.5
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', loss='mse', metrics=[tf.metrics.CategoricalAccuracy()])
+ ```
"""
def __init__(self, name='categorical_accuracy', dtype=None):
@@ -781,9 +797,17 @@
return super(CategoricalAccuracy, cls).from_config(config)
+@tf_export(
+ 'metrics.SparseCategoricalAccuracy',
+ 'keras.metrics.SparseCategoricalAccuracy')
class SparseCategoricalAccuracy(MeanMetricWrapper):
"""Calculates how often predictions matches integer labels.
+ For example, if `y_true` is [[2], [1]] and `y_pred` is
+ [[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5.
+ If the weights were specified as [0.7, 0.3] then the categorical accuracy
+ would be .3.
+
This metric creates two local variables, `total` and `count` that are used to
compute the frequency with which `y_pred` matches `y_true`. This frequency is
ultimately returned as `sparse categorical accuracy`: an idempotent operation
@@ -791,6 +815,24 @@
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
+
+ Usage:
+
+ ```python
+ m = tf.metrics.SparseCategoricalAccuracy()
+ m.update_state([[2], [1]], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
+ print('Final result: ', m.result().numpy()) # Final result: 0.5
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile(
+ 'sgd',
+ loss='mse',
+ metrics=[tf.metrics.SparseCategoricalAccuracy()])
+ ```
"""
def __init__(self, name='sparse_categorical_accuracy', dtype=None):
@@ -816,21 +858,22 @@
Args:
confusion_matrix_cond: One of `_ConfusionMatrix` conditions.
- thresholds: (Optional) Defaults to [0.5]. A python list/tuple of float
- threshold values in [0, 1]. A threshold is compared with prediction
- values to determine the truth value of predictions (i.e., above the
- threshold is `true`, below is `false`). One metric value is generated
- for each threshold value.
+ thresholds: (Optional) Defaults to 0.5. A float value or a python
+ list/tuple of float threshold values in [0, 1]. A threshold is compared
+ with prediction values to determine the truth value of predictions
+ (i.e., above the threshold is `true`, below is `false`). One metric
+ value is generated for each threshold value.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
"""
super(_ConfusionMatrixConditionCount, self).__init__(name=name, dtype=dtype)
self._confusion_matrix_cond = confusion_matrix_cond
- self.thresholds = [0.5] if thresholds is None else thresholds
- _assert_thresholds_range(self.thresholds)
+ self.thresholds = 0.5 if thresholds is None else thresholds
+ thresholds = to_list(thresholds)
+ _assert_thresholds_range(thresholds)
self.accumulator = self.add_weight(
'accumulator',
- shape=(len(self.thresholds),),
+ shape=(len(thresholds),),
initializer=init_ops.zeros_initializer)
def update_state(self, y_true, y_pred, sample_weight=None):
@@ -851,29 +894,53 @@
}, y_true, y_pred, self.thresholds, sample_weight)
def result(self):
- return ops.convert_to_tensor(self.accumulator)
+ if isinstance(self.thresholds, (list, tuple)):
+ result = self.accumulator
+ else:
+ result = self.accumulator[0]
+ return ops.convert_to_tensor(result)
+@tf_export('metrics.FalsePositives', 'keras.metrics.FalsePositives')
class FalsePositives(_ConfusionMatrixConditionCount):
"""Calculates the number of false positives.
+ For example, if `y_true` is [0, 1, 0, 0] and `y_pred` is [0, 0, 1, 1]
+ then the false positives value is 2. If the weights were specified as
+ [0, 0, 1, 0] then the false positives value would be 1.
+
If `sample_weight` is given, calculates the sum of the weights of
false positives. This metric creates one local variable, `accumulator`
that is used to keep track of the number of false positives.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
+
+ Usage:
+
+ ```python
+ m = tf.metrics.FalsePositives()
+ m.update_state([0, 1, 0, 0], [0, 0, 1, 1])
+ print('Final result: ', m.result().numpy()) # Final result: 2
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', loss='mse', metrics=[tf.metrics.FalsePositives()])
+ ```
"""
def __init__(self, thresholds=None, name=None, dtype=None):
"""Creates a `FalsePositives` instance.
Args:
- thresholds: (Optional) Defaults to [0.5]. A python list/tuple of float
- threshold values in [0, 1]. A threshold is compared with prediction
- values to determine the truth value of predictions (i.e., above the
- threshold is `true`, below is `false`). One metric value is generated
- for each threshold value.
+ thresholds: (Optional) Defaults to 0.5. A float value or a python
+ list/tuple of float threshold values in [0, 1]. A threshold is compared
+ with prediction values to determine the truth value of predictions
+ (i.e., above the threshold is `true`, below is `false`). One metric
+ value is generated for each threshold value.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
"""
@@ -884,26 +951,46 @@
dtype=dtype)
+@tf_export('metrics.FalseNegatives', 'keras.metrics.FalseNegatives')
class FalseNegatives(_ConfusionMatrixConditionCount):
"""Calculates the number of false negatives.
+ For example, if `y_true` is [0, 1, 1, 1] and `y_pred` is [0, 1, 0, 0]
+ then the false negatives value is 2. If the weights were specified as
+ [0, 0, 1, 0] then the false negatives value would be 1.
+
If `sample_weight` is given, calculates the sum of the weights of
false negatives. This metric creates one local variable, `accumulator`
that is used to keep track of the number of false negatives.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
+
+ Usage:
+
+ ```python
+ m = tf.metrics.FalseNegatives()
+ m.update_state([0, 1, 1, 1], [0, 1, 0, 0])
+ print('Final result: ', m.result().numpy()) # Final result: 2
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', loss='mse', metrics=[tf.metrics.FalseNegatives()])
+ ```
"""
def __init__(self, thresholds=None, name=None, dtype=None):
"""Creates a `FalseNegatives` instance.
Args:
- thresholds: (Optional) Defaults to [0.5]. A python list/tuple of float
- threshold values in [0, 1]. A threshold is compared with prediction
- values to determine the truth value of predictions (i.e., above the
- threshold is `true`, below is `false`). One metric value is generated
- for each threshold value.
+ thresholds: (Optional) Defaults to 0.5. A float value or a python
+ list/tuple of float threshold values in [0, 1]. A threshold is compared
+ with prediction values to determine the truth value of predictions
+ (i.e., above the threshold is `true`, below is `false`). One metric
+ value is generated for each threshold value.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
"""
@@ -914,26 +1001,46 @@
dtype=dtype)
+@tf_export('metrics.TrueNegatives', 'keras.metrics.TrueNegatives')
class TrueNegatives(_ConfusionMatrixConditionCount):
"""Calculates the number of true negatives.
+ For example, if `y_true` is [0, 1, 0, 0] and `y_pred` is [1, 1, 0, 0]
+ then the true negatives value is 2. If the weights were specified as
+ [0, 0, 1, 0] then the true negatives value would be 1.
+
If `sample_weight` is given, calculates the sum of the weights of
true negatives. This metric creates one local variable, `accumulator`
that is used to keep track of the number of true negatives.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
+
+ Usage:
+
+ ```python
+ m = tf.metrics.TrueNegatives()
+ m.update_state([0, 1, 0, 0], [1, 1, 0, 0])
+ print('Final result: ', m.result().numpy()) # Final result: 2
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', loss='mse', metrics=[tf.metrics.TrueNegatives()])
+ ```
"""
def __init__(self, thresholds=None, name=None, dtype=None):
"""Creates a `TrueNegatives` instance.
Args:
- thresholds: (Optional) Defaults to [0.5]. A python list/tuple of float
- threshold values in [0, 1]. A threshold is compared with prediction
- values to determine the truth value of predictions (i.e., above the
- threshold is `true`, below is `false`). One metric value is generated
- for each threshold value.
+ thresholds: (Optional) Defaults to 0.5. A float value or a python
+ list/tuple of float threshold values in [0, 1]. A threshold is compared
+ with prediction values to determine the truth value of predictions
+ (i.e., above the threshold is `true`, below is `false`). One metric
+ value is generated for each threshold value.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
"""
@@ -944,26 +1051,46 @@
dtype=dtype)
+@tf_export('metrics.TruePositives', 'keras.metrics.TruePositives')
class TruePositives(_ConfusionMatrixConditionCount):
"""Calculates the number of true positives.
+ For example, if `y_true` is [0, 1, 1, 1] and `y_pred` is [1, 0, 1, 1]
+ then the true positives value is 2. If the weights were specified as
+ [0, 0, 1, 0] then the true positives value would be 1.
+
If `sample_weight` is given, calculates the sum of the weights of
true positives. This metric creates one local variable, `true_positives`
that is used to keep track of the number of true positives.
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
+
+ Usage:
+
+ ```python
+ m = tf.metrics.TruePositives()
+ m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
+ print('Final result: ', m.result().numpy()) # Final result: 2
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', loss='mse', metrics=[tf.metrics.TruePositives()])
+ ```
"""
def __init__(self, thresholds=None, name=None, dtype=None):
"""Creates a `TruePositives` instance.
Args:
- thresholds: (Optional) Defaults to [0.5]. A python list/tuple of float
- threshold values in [0, 1]. A threshold is compared with prediction
- values to determine the truth value of predictions (i.e., above the
- threshold is `true`, below is `false`). One metric value is generated
- for each threshold value.
+ thresholds: (Optional) Defaults to 0.5. A float value or a python
+ list/tuple of float threshold values in [0, 1]. A threshold is compared
+ with prediction values to determine the truth value of predictions
+ (i.e., above the threshold is `true`, below is `false`). One metric
+ value is generated for each threshold value.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
"""
@@ -974,6 +1101,368 @@
dtype=dtype)
+@tf_export('metrics.Precision', 'keras.metrics.Precision')
+class Precision(Metric):
+ """Computes the precision of the predictions with respect to the labels.
+
+ For example, if `y_true` is [0, 1, 1, 1] and `y_pred` is [1, 0, 1, 1]
+ then the precision value is 2/(2+1) ie. 0.66. If the weights were specified as
+ [0, 0, 1, 0] then the precision value would be 1.
+
+ The metric creates two local variables, `true_positives` and `false_positives`
+ that are used to compute the precision. This value is ultimately returned as
+ `precision`, an idempotent operation that simply divides `true_positives`
+ by the sum of `true_positives` and `false_positives`.
+
+ If `sample_weight` is `None`, weights default to 1.
+ Use `sample_weight` of 0 to mask values.
+
+ Usage:
+
+ ```python
+ m = tf.metrics.Precision()
+ m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
+ print('Final result: ', m.result().numpy()) # Final result: 0.66
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', loss='mse', metrics=[tf.metrics.Precision()])
+ ```
+ """
+
+ def __init__(self, thresholds=None, name=None, dtype=None):
+ """Creates a `Precision` instance.
+
+ Args:
+ thresholds: (Optional) Defaults to 0.5. A float value or a python
+ list/tuple of float threshold values in [0, 1]. A threshold is compared
+ with prediction values to determine the truth value of predictions
+ (i.e., above the threshold is `true`, below is `false`). One metric
+ value is generated for each threshold value.
+ name: (Optional) string name of the metric instance.
+ dtype: (Optional) data type of the metric result.
+ """
+ super(Precision, self).__init__(name=name, dtype=dtype)
+ self.thresholds = 0.5 if thresholds is None else thresholds
+ thresholds = to_list(thresholds)
+ _assert_thresholds_range(thresholds)
+ self.tp = self.add_weight(
+ 'true_positives',
+ shape=(len(thresholds),),
+ initializer=init_ops.zeros_initializer)
+ self.fp = self.add_weight(
+ 'false_positives',
+ shape=(len(thresholds),),
+ initializer=init_ops.zeros_initializer)
+
+ def update_state(self, y_true, y_pred, sample_weight=None):
+ """Accumulates true positive and false positive statistics.
+
+ Args:
+ y_true: The ground truth values.
+ y_pred: The predicted values.
+ sample_weight: Optional weighting of each example. Defaults to 1. Can be a
+ `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
+ be broadcastable to `y_true`.
+
+ Returns:
+ Update op.
+ """
+ return _update_confusion_matrix_variables({
+ _ConfusionMatrix.TRUE_POSITIVES: self.tp,
+ _ConfusionMatrix.FALSE_POSITIVES: self.fp
+ }, y_true, y_pred, self.thresholds, sample_weight)
+
+ def result(self):
+ result = math_ops.div_no_nan(self.tp, self.tp + self.fp)
+ return result if isinstance(self.thresholds, (list, tuple)) else result[0]
+
+
+@tf_export('metrics.Recall', 'keras.metrics.Recall')
+class Recall(Metric):
+ """Computes the recall of the predictions with respect to the labels.
+
+ For example, if `y_true` is [0, 1, 1, 1] and `y_pred` is [1, 0, 1, 1]
+ then the recall value is 2/(2+1) ie. 0.66. If the weights were specified as
+ [0, 0, 1, 0] then the recall value would be 1.
+
+ This metric creates two local variables, `true_positives` and
+ `false_negatives`, that are used to compute the recall. This value is
+ ultimately returned as `recall`, an idempotent operation that simply divides
+ `true_positives` by the sum of `true_positives` and `false_negatives`.
+
+ If `sample_weight` is `None`, weights default to 1.
+ Use `sample_weight` of 0 to mask values.
+
+ Usage:
+
+ ```python
+ m = tf.metrics.Recall()
+ m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
+ print('Final result: ', m.result().numpy()) # Final result: 0.66
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile('sgd', loss='mse', metrics=[tf.metrics.Recall()])
+ ```
+ """
+
+ def __init__(self, thresholds=None, name=None, dtype=None):
+ """Creates a `Recall` instance.
+
+ Args:
+ thresholds: (Optional) Defaults to 0.5. A float value or a python
+ list/tuple of float threshold values in [0, 1]. A threshold is compared
+ with prediction values to determine the truth value of predictions
+ (i.e., above the threshold is `true`, below is `false`). One metric
+ value is generated for each threshold value.
+ name: (Optional) string name of the metric instance.
+ dtype: (Optional) data type of the metric result.
+ """
+ super(Recall, self).__init__(name=name, dtype=dtype)
+ self.thresholds = 0.5 if thresholds is None else thresholds
+ thresholds = to_list(thresholds)
+ _assert_thresholds_range(thresholds)
+ self.tp = self.add_weight(
+ 'true_positives',
+ shape=(len(thresholds),),
+ initializer=init_ops.zeros_initializer)
+ self.fn = self.add_weight(
+ 'false_negatives',
+ shape=(len(thresholds),),
+ initializer=init_ops.zeros_initializer)
+
+ def update_state(self, y_true, y_pred, sample_weight=None):
+ """Accumulates true positive and false negative statistics.
+
+ Args:
+ y_true: The ground truth values.
+ y_pred: The predicted values.
+ sample_weight: Optional weighting of each example. Defaults to 1. Can be a
+ `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
+ be broadcastable to `y_true`.
+
+ Returns:
+ Update op.
+ """
+ return _update_confusion_matrix_variables({
+ _ConfusionMatrix.TRUE_POSITIVES: self.tp,
+ _ConfusionMatrix.FALSE_NEGATIVES: self.fn
+ }, y_true, y_pred, self.thresholds, sample_weight)
+
+ def result(self):
+ result = math_ops.div_no_nan(self.tp, self.tp + self.fn)
+ return result if isinstance(self.thresholds, (list, tuple)) else result[0]
+
+
+@six.add_metaclass(abc.ABCMeta)
+class SensitivitySpecificityBase(Metric):
+ """Abstract base class for computing sensitivity and specificity.
+
+ For additional information about specificity and sensitivity, see the
+ following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
+ """
+
+ def __init__(self, value, num_thresholds=200, name=None, dtype=None):
+ super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype)
+ if num_thresholds <= 0:
+ raise ValueError('`num_thresholds` must be > 0.')
+ self.value = value
+ self.tp = self.add_weight(
+ 'true_positives',
+ shape=(num_thresholds,),
+ initializer=init_ops.zeros_initializer)
+ self.tn = self.add_weight(
+ 'true_negatives',
+ shape=(num_thresholds,),
+ initializer=init_ops.zeros_initializer)
+ self.fp = self.add_weight(
+ 'false_positives',
+ shape=(num_thresholds,),
+ initializer=init_ops.zeros_initializer)
+ self.fn = self.add_weight(
+ 'false_negatives',
+ shape=(num_thresholds,),
+ initializer=init_ops.zeros_initializer)
+
+ # Compute `num_thresholds` thresholds in [0, 1]
+ if num_thresholds == 1:
+ self.thresholds = [0.5]
+ else:
+ thresholds = [(i + 1) * 1.0 / (num_thresholds - 1)
+ for i in range(num_thresholds - 2)]
+ self.thresholds = [0.0] + thresholds + [1.0]
+
+ def update_state(self, y_true, y_pred, sample_weight=None):
+ """Accumulates confusion matrix statistics.
+
+ Args:
+ y_true: The ground truth values.
+ y_pred: The predicted values.
+ sample_weight: Optional weighting of each example. Defaults to 1. Can be a
+ `Tensor` whose rank is either 0, or the same rank as `y_true`, and must
+ be broadcastable to `y_true`.
+
+ Returns:
+ Update op.
+ """
+ return _update_confusion_matrix_variables({
+ _ConfusionMatrix.TRUE_POSITIVES: self.tp,
+ _ConfusionMatrix.TRUE_NEGATIVES: self.tn,
+ _ConfusionMatrix.FALSE_POSITIVES: self.fp,
+ _ConfusionMatrix.FALSE_NEGATIVES: self.fn,
+ }, y_true, y_pred, self.thresholds, sample_weight)
+
+
+class SensitivityAtSpecificity(SensitivitySpecificityBase):
+ """Computes the sensitivity at a given specificity.
+
+ `Sensitivity` measures the proportion of actual positives that are correctly
+ identified as such (tp / (tp + fn)).
+ `Specificity` measures the proportion of actual negatives that are correctly
+ identified as such (tn / (tn + fp)).
+
+ This metric creates four local variables, `true_positives`, `true_negatives`,
+ `false_positives` and `false_negatives` that are used to compute the
+ sensitivity at the given specificity. The threshold for the given specificity
+ value is computed and used to evaluate the corresponding sensitivity.
+
+ If `sample_weight` is `None`, weights default to 1.
+ Use `sample_weight` of 0 to mask values.
+
+ For additional information about specificity and sensitivity, see the
+ following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
+
+ Usage:
+
+ ```python
+ m = tf.metrics.SensitivityAtSpecificity(0.4, num_thresholds=1)
+ m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
+ print('Final result: ', m.result().numpy()) # Final result: 0.5
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile(
+ 'sgd',
+ loss='mse',
+ metrics=[tf.metrics.SensitivityAtSpecificity()])
+ ```
+ """
+
+ def __init__(self, specificity, num_thresholds=200, name=None, dtype=None):
+ """Creates a `SensitivityAtSpecificity` instance.
+
+ Args:
+ specificity: A scalar value in range `[0, 1]`.
+ num_thresholds: (Optional) Defaults to 200. The number of thresholds to
+ use for matching the given specificity.
+ name: (Optional) string name of the metric instance.
+ dtype: (Optional) data type of the metric result.
+ """
+ if specificity < 0 or specificity > 1:
+ raise ValueError('`specificity` must be in the range [0, 1].')
+ super(SensitivityAtSpecificity, self).__init__(
+ specificity, num_thresholds=num_thresholds, name=name, dtype=dtype)
+
+ def result(self):
+ # Calculate specificities at all the thresholds.
+ specificities = math_ops.div_no_nan(self.tn, self.tn + self.fp)
+
+ # Find the index of the threshold where the specificity is closest to the
+ # given specificity.
+ min_index = math_ops.argmin(
+ math_ops.abs(specificities - self.value), axis=0)
+ min_index = math_ops.cast(min_index, dtypes.int32)
+
+ # Compute sensitivity at that index.
+ return math_ops.div_no_nan(self.tp[min_index],
+ self.tp[min_index] + self.fn[min_index])
+
+
+class SpecificityAtSensitivity(SensitivitySpecificityBase):
+ """Computes the specificity at a given sensitivity.
+
+ `Sensitivity` measures the proportion of actual positives that are correctly
+ identified as such (tp / (tp + fn)).
+ `Specificity` measures the proportion of actual negatives that are correctly
+ identified as such (tn / (tn + fp)).
+
+ This metric creates four local variables, `true_positives`, `true_negatives`,
+ `false_positives` and `false_negatives` that are used to compute the
+ specificity at the given sensitivity. The threshold for the given sensitivity
+ value is computed and used to evaluate the corresponding specificity.
+
+ If `sample_weight` is `None`, weights default to 1.
+ Use `sample_weight` of 0 to mask values.
+
+ For additional information about specificity and sensitivity, see the
+ following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
+
+ Usage:
+
+ ```python
+ m = tf.metrics.SpecificityAtSensitivity(0.8, num_thresholds=1)
+ m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
+ print('Final result: ', m.result().numpy()) # Final result: 1.0
+ ```
+
+ Usage with tf.keras API:
+
+ ```python
+ model = keras.models.Model(inputs, outputs)
+ model.compile(
+ 'sgd',
+ loss='mse',
+ metrics=[tf.metrics.SpecificityAtSensitivity()])
+ ```
+ """
+
+ def __init__(self, sensitivity, num_thresholds=200, name=None, dtype=None):
+ """Creates a `SpecificityAtSensitivity` instance.
+
+ Args:
+ sensitivity: A scalar value in range `[0, 1]`.
+ num_thresholds: (Optional) Defaults to 200. The number of thresholds to
+ use for matching the given specificity.
+ name: (Optional) string name of the metric instance.
+ dtype: (Optional) data type of the metric result.
+ """
+ if sensitivity < 0 or sensitivity > 1:
+ raise ValueError('`sensitivity` must be in the range [0, 1].')
+ super(SpecificityAtSensitivity, self).__init__(
+ sensitivity, num_thresholds=num_thresholds, name=name, dtype=dtype)
+
+ def result(self):
+ # Calculate sensitivities at all the thresholds.
+ sensitivities = math_ops.div_no_nan(self.tp, self.tp + self.fn)
+
+ # Find the index of the threshold where the sensitivity is closest to the
+ # given specificity.
+ min_index = math_ops.argmin(
+ math_ops.abs(sensitivities - self.value), axis=0)
+ min_index = math_ops.cast(min_index, dtypes.int32)
+
+ # Compute specificity at that index.
+ return math_ops.div_no_nan(self.tn[min_index],
+ self.tn[min_index] + self.fp[min_index])
+
+
+def accuracy(y_true, y_pred):
+ y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
+ if y_true.dtype != y_pred.dtype:
+ y_pred = math_ops.cast(y_pred, y_true.dtype)
+ return math_ops.cast(math_ops.equal(y_true, y_pred), K.floatx())
+
+
@tf_export('keras.metrics.binary_accuracy')
def binary_accuracy(y_true, y_pred, threshold=0.5):
threshold = math_ops.cast(threshold, y_pred.dtype)
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index c6a49c3..40611a5 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -19,6 +19,7 @@
from __future__ import print_function
import os
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.eager import context
@@ -27,12 +28,10 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend as K
-from tensorflow.python.keras import layers
from tensorflow.python.keras import metrics
-from tensorflow.python.keras.engine.training import Model
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@@ -48,7 +47,7 @@
output = metric(y_a, y_b)
self.assertEqual(K.eval(output).shape, (6,))
- def test_sparse_categorical_accuracy(self):
+ def test_sparse_categorical_accuracy_int(self):
with self.cached_session():
metric = metrics.sparse_categorical_accuracy
y_true = K.variable(np.random.randint(0, 7, (6,)))
@@ -129,116 +128,6 @@
result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=1))
self.assertEqual(result, 0.)
- def test_stateful_metrics(self):
- with self.cached_session():
- np.random.seed(1334)
-
- class BinaryTruePositives(layers.Layer):
- """Stateful Metric to count the total true positives over all batches.
-
- Assumes predictions and targets of shape `(samples, 1)`.
-
- Arguments:
- threshold: Float, lower limit on prediction value that counts as a
- positive class prediction.
- name: String, name for the metric.
- """
-
- def __init__(self, name='true_positives', **kwargs):
- super(BinaryTruePositives, self).__init__(name=name, **kwargs)
- self.true_positives = K.variable(value=0, dtype='int32')
- self.stateful = True
-
- def reset_states(self):
- K.set_value(self.true_positives, 0)
-
- def __call__(self, y_true, y_pred):
- """Computes the number of true positives in a batch.
-
- Args:
- y_true: Tensor, batch_wise labels
- y_pred: Tensor, batch_wise predictions
-
- Returns:
- The total number of true positives seen this epoch at the
- completion of the batch.
- """
- y_true = math_ops.cast(y_true, 'int32')
- y_pred = math_ops.cast(math_ops.round(y_pred), 'int32')
- correct_preds = math_ops.cast(math_ops.equal(y_pred, y_true), 'int32')
- true_pos = math_ops.cast(
- math_ops.reduce_sum(correct_preds * y_true), 'int32')
- current_true_pos = self.true_positives * 1
- self.add_update(
- state_ops.assign_add(self.true_positives, true_pos),
- inputs=[y_true, y_pred])
- return current_true_pos + true_pos
-
- metric_fn = BinaryTruePositives()
- config = metrics.serialize(metric_fn)
- metric_fn = metrics.deserialize(
- config, custom_objects={'BinaryTruePositives': BinaryTruePositives})
-
- # Test on simple model
- inputs = layers.Input(shape=(2,))
- outputs = layers.Dense(1, activation='sigmoid')(inputs)
- model = Model(inputs, outputs)
- model.compile(optimizer='sgd',
- loss='binary_crossentropy',
- metrics=['acc', metric_fn])
-
- # Test fit, evaluate
- samples = 100
- x = np.random.random((samples, 2))
- y = np.random.randint(2, size=(samples, 1))
- val_samples = 10
- val_x = np.random.random((val_samples, 2))
- val_y = np.random.randint(2, size=(val_samples, 1))
-
- history = model.fit(x, y,
- epochs=1,
- batch_size=10,
- validation_data=(val_x, val_y))
- outs = model.evaluate(x, y, batch_size=10)
- preds = model.predict(x)
-
- def ref_true_pos(y_true, y_pred):
- return np.sum(np.logical_and(y_pred > 0.5, y_true == 1))
-
- # Test correctness (e.g. updates should have been run)
- self.assertAllClose(outs[2], ref_true_pos(y, preds), atol=1e-5)
-
- # Test correctness of the validation metric computation
- val_preds = model.predict(val_x)
- val_outs = model.evaluate(val_x, val_y, batch_size=10)
- self.assertAllClose(
- val_outs[2], ref_true_pos(val_y, val_preds), atol=1e-5)
- self.assertAllClose(
- val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)
-
- # Test with generators
- gen = [(np.array([x0]), np.array([y0])) for x0, y0 in zip(x, y)]
- val_gen = [(np.array([x0]), np.array([y0]))
- for x0, y0 in zip(val_x, val_y)]
- history = model.fit_generator(iter(gen),
- epochs=1,
- steps_per_epoch=samples,
- validation_data=iter(val_gen),
- validation_steps=val_samples)
- outs = model.evaluate_generator(iter(gen), steps=samples)
- preds = model.predict_generator(iter(gen), steps=samples)
-
- # Test correctness of the metric results
- self.assertAllClose(outs[2], ref_true_pos(y, preds), atol=1e-5)
-
- # Test correctness of the validation metric computation
- val_preds = model.predict_generator(iter(val_gen), steps=val_samples)
- val_outs = model.evaluate_generator(iter(val_gen), steps=val_samples)
- self.assertAllClose(
- val_outs[2], ref_true_pos(val_y, val_preds), atol=1e-5)
- self.assertAllClose(
- val_outs[2], history.history['val_true_positives'][-1], atol=1e-5)
-
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def test_mean(self):
m = metrics.Mean(name='my_mean')
@@ -320,19 +209,19 @@
m = metrics.Mean()
v = array_ops.placeholder(dtypes.float32)
w = array_ops.placeholder(dtypes.float32)
- sess.run(variables.variables_initializer(m.variables))
+ self.evaluate(variables.variables_initializer(m.variables))
# check __call__()
result_t = m(v, sample_weight=w)
result = sess.run(result_t, feed_dict=({v: 100, w: 0.5}))
- self.assertEqual(sess.run(m.total), 50)
- self.assertEqual(sess.run(m.count), 0.5)
+ self.assertEqual(self.evaluate(m.total), 50)
+ self.assertEqual(self.evaluate(m.count), 0.5)
self.assertEqual(result, 50 / 0.5)
# check update_state() and result()
result = sess.run(result_t, feed_dict=({v: [1, 5], w: [1, 0.2]}))
- self.assertAlmostEqual(sess.run(m.total), 52, 2) # 50 + 1 + 5 * 0.2
- self.assertAlmostEqual(sess.run(m.count), 1.7, 2) # 0.5 + 1.2
+ self.assertAlmostEqual(self.evaluate(m.total), 52, 2) # 50 + 1 + 5 * 0.2
+ self.assertAlmostEqual(self.evaluate(m.count), 1.7, 2) # 0.5 + 1.2
self.assertAlmostEqual(result, 52 / 1.7, 2)
@test_util.run_in_graph_and_eager_modes
@@ -367,6 +256,28 @@
self.assertEqual(3, self.evaluate(restore_mean.count))
@test_util.run_in_graph_and_eager_modes
+ def test_accuracy(self):
+ acc_obj = metrics.Accuracy(name='my acc')
+
+ # check config
+ self.assertEqual(acc_obj.name, 'my acc')
+ self.assertTrue(acc_obj.stateful)
+ self.assertEqual(len(acc_obj.variables), 2)
+ self.assertEqual(acc_obj.dtype, dtypes.float32)
+ self.evaluate(variables.variables_initializer(acc_obj.variables))
+
+ # verify that correct value is returned
+ update_op = acc_obj.update_state([[1], [2], [3], [4]], [[1], [2], [3], [4]])
+ self.evaluate(update_op)
+ result = self.evaluate(acc_obj.result())
+ self.assertEqual(result, 1) # 2/2
+
+ # check with sample_weight
+ result_t = acc_obj([[2], [1]], [[2], [0]], sample_weight=[[0.5], [0.2]])
+ result = self.evaluate(result_t)
+ self.assertAlmostEqual(result, 0.96, 2) # 4.5/4.7
+
+ @test_util.run_in_graph_and_eager_modes
def test_binary_accuracy(self):
acc_obj = metrics.BinaryAccuracy(name='my acc')
@@ -399,11 +310,6 @@
result = self.evaluate(result_t)
self.assertAlmostEqual(result, 0.67, 2) # 4.5/6.7
- # check incompatible shapes
- with self.assertRaisesRegexp(ValueError,
- r'Shapes \(1,\) and \(2,\) are incompatible'):
- acc_obj.update_state([1, 1], [1])
-
@test_util.run_in_graph_and_eager_modes
def test_binary_accuracy_threshold(self):
acc_obj = metrics.BinaryAccuracy(threshold=0.7)
@@ -437,46 +343,28 @@
self.assertAlmostEqual(result, 0.93, 2) # 2.5/2.7
@test_util.run_in_graph_and_eager_modes
- def test_invalid_result(self):
+ def test_sparse_categorical_accuracy(self):
+ acc_obj = metrics.SparseCategoricalAccuracy(name='my acc')
- class InvalidResult(metrics.Metric):
+ # check config
+ self.assertEqual(acc_obj.name, 'my acc')
+ self.assertTrue(acc_obj.stateful)
+ self.assertEqual(len(acc_obj.variables), 2)
+ self.assertEqual(acc_obj.dtype, dtypes.float32)
+ self.evaluate(variables.variables_initializer(acc_obj.variables))
- def __init__(self, name='invalid-result', dtype=dtypes.float64):
- super(InvalidResult, self).__init__(name=name, dtype=dtype)
+ # verify that correct value is returned
+ update_op = acc_obj.update_state([[2], [1]],
+ [[0.1, 0.1, 0.8], [0.05, 0.95, 0]])
+ self.evaluate(update_op)
+ result = self.evaluate(acc_obj.result())
+ self.assertEqual(result, 1) # 2/2
- def update_state(self, *args, **kwargs):
- pass
-
- def result(self):
- return 1
-
- invalid_result_obj = InvalidResult()
- with self.assertRaisesRegexp(
- TypeError,
- 'Metric invalid-result\'s result must be a Tensor or Operation, given:'
- ):
- invalid_result_obj.result()
-
- @test_util.run_in_graph_and_eager_modes
- def test_invalid_update(self):
-
- class InvalidUpdate(metrics.Metric):
-
- def __init__(self, name='invalid-update', dtype=dtypes.float64):
- super(InvalidUpdate, self).__init__(name=name, dtype=dtype)
-
- def update_state(self, *args, **kwargs):
- return [1]
-
- def result(self):
- pass
-
- invalid_update_obj = InvalidUpdate()
- with self.assertRaisesRegexp(
- TypeError,
- 'Metric invalid-update\'s update must be a Tensor or Operation, given:'
- ):
- invalid_update_obj.update_state()
+ # check with sample_weight
+ result_t = acc_obj([[2], [1]], [[0.1, 0.1, 0.8], [0.05, 0, 0.95]],
+ [[0.5], [0.2]])
+ result = self.evaluate(result_t)
+ self.assertAlmostEqual(result, 0.93, 2) # 2.5/2.7
@test_util.run_all_in_graph_and_eager_modes
@@ -500,7 +388,7 @@
update_op = fp_obj.update_state(y_true, y_pred)
self.evaluate(update_op)
result = fp_obj.result()
- self.assertAllClose([7.], result)
+ self.assertAllClose(7., result)
def test_weighted(self):
fp_obj = metrics.FalsePositives()
@@ -511,7 +399,7 @@
(0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
sample_weight = constant_op.constant((1., 1.5, 2., 2.5))
result = fp_obj(y_true, y_pred, sample_weight=sample_weight)
- self.assertAllClose([14.], self.evaluate(result))
+ self.assertAllClose(14., self.evaluate(result))
def test_unweighted_with_thresholds(self):
fp_obj = metrics.FalsePositives(thresholds=[0.15, 0.5, 0.85])
@@ -569,7 +457,7 @@
update_op = fn_obj.update_state(y_true, y_pred)
self.evaluate(update_op)
result = fn_obj.result()
- self.assertAllClose([3.], result)
+ self.assertAllClose(3., result)
def test_weighted(self):
fn_obj = metrics.FalseNegatives()
@@ -580,7 +468,7 @@
(0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
sample_weight = constant_op.constant((1., 1.5, 2., 2.5))
result = fn_obj(y_true, y_pred, sample_weight=sample_weight)
- self.assertAllClose([5.], self.evaluate(result))
+ self.assertAllClose(5., self.evaluate(result))
def test_unweighted_with_thresholds(self):
fn_obj = metrics.FalseNegatives(thresholds=[0.15, 0.5, 0.85])
@@ -631,7 +519,7 @@
update_op = tn_obj.update_state(y_true, y_pred)
self.evaluate(update_op)
result = tn_obj.result()
- self.assertAllClose([3.], result)
+ self.assertAllClose(3., result)
def test_weighted(self):
tn_obj = metrics.TrueNegatives()
@@ -642,7 +530,7 @@
(0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
sample_weight = constant_op.constant((1., 1.5, 2., 2.5))
result = tn_obj(y_true, y_pred, sample_weight=sample_weight)
- self.assertAllClose([4.], self.evaluate(result))
+ self.assertAllClose(4., self.evaluate(result))
def test_unweighted_with_thresholds(self):
tn_obj = metrics.TrueNegatives(thresholds=[0.15, 0.5, 0.85])
@@ -693,7 +581,7 @@
update_op = tp_obj.update_state(y_true, y_pred)
self.evaluate(update_op)
result = tp_obj.result()
- self.assertAllClose([7.], result)
+ self.assertAllClose(7., result)
def test_weighted(self):
tp_obj = metrics.TruePositives()
@@ -704,7 +592,7 @@
(0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
sample_weight = constant_op.constant((1., 1.5, 2., 2.5))
result = tp_obj(y_true, y_pred, sample_weight=sample_weight)
- self.assertAllClose([12.], self.evaluate(result))
+ self.assertAllClose(12., self.evaluate(result))
def test_unweighted_with_thresholds(self):
tp_obj = metrics.TruePositives(thresholds=[0.15, 0.5, 0.85])
@@ -733,5 +621,406 @@
self.assertAllClose([222., 111., 37.], self.evaluate(result))
+@test_util.run_all_in_graph_and_eager_modes
+class PrecisionTest(test.TestCase):
+
+ def test_config(self):
+ p_obj = metrics.Precision(name='my_precision', thresholds=[0.4, 0.9])
+ self.assertEqual(p_obj.name, 'my_precision')
+ self.assertLen(p_obj.variables, 2)
+ self.assertEqual([v.name for v in p_obj.variables],
+ ['true_positives:0', 'false_positives:0'])
+ self.assertEqual(p_obj.thresholds, [0.4, 0.9])
+
+ def test_value_is_idempotent(self):
+ p_obj = metrics.Precision(thresholds=[0.3, 0.72])
+ y_pred = random_ops.random_uniform(shape=(10, 3))
+ y_true = random_ops.random_uniform(shape=(10, 3))
+ update_op = p_obj.update_state(y_true, y_pred)
+ self.evaluate(variables.variables_initializer(p_obj.variables))
+
+ # Run several updates.
+ for _ in range(10):
+ self.evaluate(update_op)
+
+ # Then verify idempotency.
+ initial_precision = self.evaluate(p_obj.result())
+ for _ in range(10):
+ self.assertArrayNear(initial_precision, self.evaluate(p_obj.result()),
+ 1e-3)
+
+ def test_unweighted(self):
+ p_obj = metrics.Precision()
+ y_pred = constant_op.constant([1, 0, 1, 0], shape=(1, 4))
+ y_true = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
+ self.evaluate(variables.variables_initializer(p_obj.variables))
+ result = p_obj(y_true, y_pred)
+ self.assertAlmostEqual(0.5, self.evaluate(result))
+
+ def test_unweighted_all_incorrect(self):
+ p_obj = metrics.Precision(thresholds=[0.5])
+ inputs = np.random.randint(0, 2, size=(100, 1))
+ y_pred = constant_op.constant(inputs)
+ y_true = constant_op.constant(1 - inputs)
+ self.evaluate(variables.variables_initializer(p_obj.variables))
+ result = p_obj(y_true, y_pred)
+ self.assertAlmostEqual(0, self.evaluate(result))
+
+ def test_weighted(self):
+ p_obj = metrics.Precision()
+ y_pred = constant_op.constant([[1, 0, 1, 0], [1, 0, 1, 0]])
+ y_true = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
+ self.evaluate(variables.variables_initializer(p_obj.variables))
+ result = p_obj(
+ y_true,
+ y_pred,
+ sample_weight=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
+ weighted_tp = 3.0 + 4.0
+ weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
+ expected_precision = weighted_tp / weighted_positives
+ self.assertAlmostEqual(expected_precision, self.evaluate(result))
+
+ def test_div_by_zero(self):
+ p_obj = metrics.Precision()
+ y_pred = constant_op.constant([0, 0, 0, 0])
+ y_true = constant_op.constant([0, 0, 0, 0])
+ self.evaluate(variables.variables_initializer(p_obj.variables))
+ result = p_obj(y_true, y_pred)
+ self.assertEqual(0, self.evaluate(result))
+
+ def test_unweighted_with_threshold(self):
+ p_obj = metrics.Precision(thresholds=[0.5, 0.7])
+ y_pred = constant_op.constant([1, 0, 0.6, 0], shape=(1, 4))
+ y_true = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
+ self.evaluate(variables.variables_initializer(p_obj.variables))
+ result = p_obj(y_true, y_pred)
+ self.assertArrayNear([0.5, 0.], self.evaluate(result), 0)
+
+ def test_weighted_with_threshold(self):
+ p_obj = metrics.Precision(thresholds=[0.5, 1.])
+ y_true = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+ y_pred = constant_op.constant([[1, 0], [0.6, 0]],
+ shape=(2, 2),
+ dtype=dtypes.float32)
+ weights = constant_op.constant([[4, 0], [3, 1]],
+ shape=(2, 2),
+ dtype=dtypes.float32)
+ self.evaluate(variables.variables_initializer(p_obj.variables))
+ result = p_obj(y_true, y_pred, sample_weight=weights)
+ weighted_tp = 0 + 3.
+ weighted_positives = (0 + 3.) + (4. + 0.)
+ expected_precision = weighted_tp / weighted_positives
+ self.assertArrayNear([expected_precision, 0], self.evaluate(result), 1e-3)
+
+ def test_multiple_updates(self):
+ p_obj = metrics.Precision(thresholds=[0.5, 1.])
+ y_true = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+ y_pred = constant_op.constant([[1, 0], [0.6, 0]],
+ shape=(2, 2),
+ dtype=dtypes.float32)
+ weights = constant_op.constant([[4, 0], [3, 1]],
+ shape=(2, 2),
+ dtype=dtypes.float32)
+ self.evaluate(variables.variables_initializer(p_obj.variables))
+ update_op = p_obj.update_state(y_true, y_pred, sample_weight=weights)
+ for _ in range(2):
+ self.evaluate(update_op)
+
+ weighted_tp = (0 + 3.) + (0 + 3.)
+ weighted_positives = ((0 + 3.) + (4. + 0.)) + ((0 + 3.) + (4. + 0.))
+ expected_precision = weighted_tp / weighted_positives
+ self.assertArrayNear([expected_precision, 0], self.evaluate(p_obj.result()),
+ 1e-3)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class RecallTest(test.TestCase):
+
+ def test_config(self):
+ r_obj = metrics.Recall(name='my_recall', thresholds=[0.4, 0.9])
+ self.assertEqual(r_obj.name, 'my_recall')
+ self.assertLen(r_obj.variables, 2)
+ self.assertEqual([v.name for v in r_obj.variables],
+ ['true_positives:0', 'false_negatives:0'])
+ self.assertEqual(r_obj.thresholds, [0.4, 0.9])
+
+ def test_value_is_idempotent(self):
+ r_obj = metrics.Recall(thresholds=[0.3, 0.72])
+ y_pred = random_ops.random_uniform(shape=(10, 3))
+ y_true = random_ops.random_uniform(shape=(10, 3))
+ update_op = r_obj.update_state(y_true, y_pred)
+ self.evaluate(variables.variables_initializer(r_obj.variables))
+
+ # Run several updates.
+ for _ in range(10):
+ self.evaluate(update_op)
+
+ # Then verify idempotency.
+ initial_recall = self.evaluate(r_obj.result())
+ for _ in range(10):
+ self.assertArrayNear(initial_recall, self.evaluate(r_obj.result()), 1e-3)
+
+ def test_unweighted(self):
+ r_obj = metrics.Recall()
+ y_pred = constant_op.constant([1, 0, 1, 0], shape=(1, 4))
+ y_true = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
+ self.evaluate(variables.variables_initializer(r_obj.variables))
+ result = r_obj(y_true, y_pred)
+ self.assertAlmostEqual(0.5, self.evaluate(result))
+
+ def test_unweighted_all_incorrect(self):
+ r_obj = metrics.Recall(thresholds=[0.5])
+ inputs = np.random.randint(0, 2, size=(100, 1))
+ y_pred = constant_op.constant(inputs)
+ y_true = constant_op.constant(1 - inputs)
+ self.evaluate(variables.variables_initializer(r_obj.variables))
+ result = r_obj(y_true, y_pred)
+ self.assertAlmostEqual(0, self.evaluate(result))
+
+ def test_weighted(self):
+ r_obj = metrics.Recall()
+ y_pred = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
+ y_true = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
+ self.evaluate(variables.variables_initializer(r_obj.variables))
+ result = r_obj(
+ y_true,
+ y_pred,
+ sample_weight=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
+ weighted_tp = 3.0 + 1.0
+ weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
+ expected_recall = weighted_tp / weighted_t
+ self.assertAlmostEqual(expected_recall, self.evaluate(result))
+
+ def test_div_by_zero(self):
+ r_obj = metrics.Recall()
+ y_pred = constant_op.constant([0, 0, 0, 0])
+ y_true = constant_op.constant([0, 0, 0, 0])
+ self.evaluate(variables.variables_initializer(r_obj.variables))
+ result = r_obj(y_true, y_pred)
+ self.assertEqual(0, self.evaluate(result))
+
+ def test_unweighted_with_threshold(self):
+ r_obj = metrics.Recall(thresholds=[0.5, 0.7])
+ y_pred = constant_op.constant([1, 0, 0.6, 0], shape=(1, 4))
+ y_true = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
+ self.evaluate(variables.variables_initializer(r_obj.variables))
+ result = r_obj(y_true, y_pred)
+ self.assertArrayNear([0.5, 0.], self.evaluate(result), 0)
+
+ def test_weighted_with_threshold(self):
+ r_obj = metrics.Recall(thresholds=[0.5, 1.])
+ y_true = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+ y_pred = constant_op.constant([[1, 0], [0.6, 0]],
+ shape=(2, 2),
+ dtype=dtypes.float32)
+ weights = constant_op.constant([[1, 4], [3, 2]],
+ shape=(2, 2),
+ dtype=dtypes.float32)
+ self.evaluate(variables.variables_initializer(r_obj.variables))
+ result = r_obj(y_true, y_pred, sample_weight=weights)
+ weighted_tp = 0 + 3.
+ weighted_positives = (0 + 3.) + (4. + 0.)
+ expected_recall = weighted_tp / weighted_positives
+ self.assertArrayNear([expected_recall, 0], self.evaluate(result), 1e-3)
+
+ def test_multiple_updates(self):
+ r_obj = metrics.Recall(thresholds=[0.5, 1.])
+ y_true = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
+ y_pred = constant_op.constant([[1, 0], [0.6, 0]],
+ shape=(2, 2),
+ dtype=dtypes.float32)
+ weights = constant_op.constant([[1, 4], [3, 2]],
+ shape=(2, 2),
+ dtype=dtypes.float32)
+ self.evaluate(variables.variables_initializer(r_obj.variables))
+ update_op = r_obj.update_state(y_true, y_pred, sample_weight=weights)
+ for _ in range(2):
+ self.evaluate(update_op)
+
+ weighted_tp = (0 + 3.) + (0 + 3.)
+ weighted_positives = ((0 + 3.) + (4. + 0.)) + ((0 + 3.) + (4. + 0.))
+ expected_recall = weighted_tp / weighted_positives
+ self.assertArrayNear([expected_recall, 0], self.evaluate(r_obj.result()),
+ 1e-3)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class SensitivityAtSpecificityTest(test.TestCase, parameterized.TestCase):
+
+ def test_config(self):
+ s_obj = metrics.SensitivityAtSpecificity(
+ 0.4, num_thresholds=100, name='sensitivity_at_specificity_1')
+ self.assertEqual(s_obj.name, 'sensitivity_at_specificity_1')
+ self.assertLen(s_obj.variables, 4)
+ self.assertEqual(s_obj.value, 0.4)
+ self.assertLen(s_obj.thresholds, 100)
+
+ def test_value_is_idempotent(self):
+ s_obj = metrics.SensitivityAtSpecificity(0.7)
+ y_pred = random_ops.random_uniform((10, 3),
+ maxval=1,
+ dtype=dtypes.float32,
+ seed=1)
+ y_true = random_ops.random_uniform((10, 3),
+ maxval=2,
+ dtype=dtypes.int64,
+ seed=1)
+ update_op = s_obj.update_state(y_true, y_pred)
+ self.evaluate(variables.variables_initializer(s_obj.variables))
+
+ # Run several updates.
+ for _ in range(10):
+ self.evaluate(update_op)
+
+ # Then verify idempotency.
+ initial_sensitivity = self.evaluate(s_obj.result())
+ for _ in range(10):
+ self.assertAlmostEqual(initial_sensitivity, self.evaluate(s_obj.result()),
+ 1e-3)
+
+ def test_unweighted_all_correct(self):
+ s_obj = metrics.SensitivityAtSpecificity(0.7)
+ inputs = np.random.randint(0, 2, size=(100, 1))
+ y_pred = constant_op.constant(inputs, dtype=dtypes.float32)
+ y_true = constant_op.constant(inputs)
+ self.evaluate(variables.variables_initializer(s_obj.variables))
+ result = s_obj(y_true, y_pred)
+ self.assertAlmostEqual(1, self.evaluate(result))
+
+ def test_unweighted_high_specificity(self):
+ s_obj = metrics.SensitivityAtSpecificity(0.8)
+ pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
+ label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+
+ y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
+ y_true = constant_op.constant(label_values)
+ self.evaluate(variables.variables_initializer(s_obj.variables))
+ result = s_obj(y_true, y_pred)
+ self.assertAlmostEqual(0.8, self.evaluate(result))
+
+ def test_unweighted_low_specificity(self):
+ s_obj = metrics.SensitivityAtSpecificity(0.4)
+ pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
+ label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+
+ y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
+ y_true = constant_op.constant(label_values)
+ self.evaluate(variables.variables_initializer(s_obj.variables))
+ result = s_obj(y_true, y_pred)
+ self.assertAlmostEqual(0.6, self.evaluate(result))
+
+ @parameterized.parameters([dtypes.bool, dtypes.int32, dtypes.float32])
+ def test_weighted(self, label_dtype):
+ s_obj = metrics.SensitivityAtSpecificity(0.4)
+ pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
+ label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+ weight_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+
+ y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
+ y_true = math_ops.cast(label_values, dtype=label_dtype)
+ weights = constant_op.constant(weight_values)
+ self.evaluate(variables.variables_initializer(s_obj.variables))
+ result = s_obj(y_true, y_pred, sample_weight=weights)
+ self.assertAlmostEqual(0.675, self.evaluate(result))
+
+ def test_invalid_specificity(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'`specificity` must be in the range \[0, 1\].'):
+ metrics.SensitivityAtSpecificity(-1)
+
+ def test_invalid_num_thresholds(self):
+ with self.assertRaisesRegexp(ValueError, '`num_thresholds` must be > 0.'):
+ metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1)
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class SpecificityAtSensitivityTest(test.TestCase, parameterized.TestCase):
+
+ def test_config(self):
+ s_obj = metrics.SpecificityAtSensitivity(
+ 0.4, num_thresholds=100, name='specificity_at_sensitivity_1')
+ self.assertEqual(s_obj.name, 'specificity_at_sensitivity_1')
+ self.assertLen(s_obj.variables, 4)
+ self.assertEqual(s_obj.value, 0.4)
+ self.assertLen(s_obj.thresholds, 100)
+
+ def test_value_is_idempotent(self):
+ s_obj = metrics.SpecificityAtSensitivity(0.7)
+ y_pred = random_ops.random_uniform((10, 3),
+ maxval=1,
+ dtype=dtypes.float32,
+ seed=1)
+ y_true = random_ops.random_uniform((10, 3),
+ maxval=2,
+ dtype=dtypes.int64,
+ seed=1)
+ update_op = s_obj.update_state(y_true, y_pred)
+ self.evaluate(variables.variables_initializer(s_obj.variables))
+
+ # Run several updates.
+ for _ in range(10):
+ self.evaluate(update_op)
+
+ # Then verify idempotency.
+ initial_specificity = self.evaluate(s_obj.result())
+ for _ in range(10):
+ self.assertAlmostEqual(initial_specificity, self.evaluate(s_obj.result()),
+ 1e-3)
+
+ def test_unweighted_all_correct(self):
+ s_obj = metrics.SpecificityAtSensitivity(0.7)
+ inputs = np.random.randint(0, 2, size=(100, 1))
+ y_pred = constant_op.constant(inputs, dtype=dtypes.float32)
+ y_true = constant_op.constant(inputs)
+ self.evaluate(variables.variables_initializer(s_obj.variables))
+ result = s_obj(y_true, y_pred)
+ self.assertAlmostEqual(1, self.evaluate(result))
+
+ def test_unweighted_high_sensitivity(self):
+ s_obj = metrics.SpecificityAtSensitivity(0.8)
+ pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
+ label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+
+ y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
+ y_true = constant_op.constant(label_values)
+ self.evaluate(variables.variables_initializer(s_obj.variables))
+ result = s_obj(y_true, y_pred)
+ self.assertAlmostEqual(0.4, self.evaluate(result))
+
+ def test_unweighted_low_sensitivity(self):
+ s_obj = metrics.SpecificityAtSensitivity(0.4)
+ pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
+ label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+
+ y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
+ y_true = constant_op.constant(label_values)
+ self.evaluate(variables.variables_initializer(s_obj.variables))
+ result = s_obj(y_true, y_pred)
+ self.assertAlmostEqual(0.6, self.evaluate(result))
+
+ @parameterized.parameters([dtypes.bool, dtypes.int32, dtypes.float32])
+ def test_weighted(self, label_dtype):
+ s_obj = metrics.SpecificityAtSensitivity(0.4)
+ pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
+ label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+ weight_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+
+ y_pred = constant_op.constant(pred_values, dtype=dtypes.float32)
+ y_true = math_ops.cast(label_values, dtype=label_dtype)
+ weights = constant_op.constant(weight_values)
+ self.evaluate(variables.variables_initializer(s_obj.variables))
+ result = s_obj(y_true, y_pred, sample_weight=weights)
+ self.assertAlmostEqual(0.4, self.evaluate(result))
+
+ def test_invalid_sensitivity(self):
+ with self.assertRaisesRegexp(
+ ValueError, r'`sensitivity` must be in the range \[0, 1\].'):
+ metrics.SpecificityAtSensitivity(-1)
+
+ def test_invalid_num_thresholds(self):
+ with self.assertRaisesRegexp(ValueError, '`num_thresholds` must be > 0.'):
+ metrics.SpecificityAtSensitivity(0.4, num_thresholds=-1)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 3a0c51b..4813b80 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -100,17 +100,19 @@
input_tensors = list(input_tensors)
input_tensors = generic_utils.to_list(input_tensors)
input_tensors_ = []
- for i, x in enumerate(input_tensors):
- if not K.is_keras_tensor(x):
- name = model._input_layers[i].name
- input_tensor = Input(tensor=x, name='input_wrapper_for_' + name)
+ for i in range(len(input_tensors)):
+ input_tensor = input_tensors[i]
+ if not K.is_keras_tensor(input_tensor):
+ original_input_layer = model._input_layers[i]
+ name = original_input_layer.name
+ input_tensor = Input(tensor=input_tensor,
+ name='input_wrapper_for_' + name)
input_tensors_.append(input_tensor)
# Cache newly created input layer.
- original_input_layer = x._keras_history[0]
newly_created_input_layer = input_tensor._keras_history[0]
layer_map[original_input_layer] = newly_created_input_layer
else:
- input_tensors_.append(x)
+ input_tensors_.append(input_tensor)
input_tensors = input_tensors_
for x, y in zip(model.inputs, input_tensors):
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index 4b6bb74..23321a2 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -26,10 +26,12 @@
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import metrics
from tensorflow.python.keras import models
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
@@ -219,6 +221,33 @@
with self.assertRaises(ValueError):
keras.models._clone_sequential_model(seq_model, input_tensors=y)
+ def test_functional_cloning_does_not_create_unnecessary_placeholders(self):
+ with ops.Graph().as_default():
+ x = keras.Input((4,))
+ y = keras.layers.Dense(4)(x)
+ model = keras.models.Model(x, y)
+ graph = ops.Graph()
+ with graph.as_default():
+ x = array_ops.ones((10, 4))
+ _ = keras.models.clone_model(model, input_tensors=[x])
+ has_placeholder = _has_placeholder(graph)
+ self.assertFalse(has_placeholder)
+
+ def test_sequential_cloning_does_not_create_unnecessary_placeholders(self):
+ with ops.Graph().as_default():
+ model = keras.models.Sequential([keras.layers.Dense(4)])
+ graph = ops.Graph()
+ with graph.as_default():
+ x = array_ops.ones((10, 4))
+ _ = keras.models.clone_model(model, input_tensors=[x])
+ has_placeholder = _has_placeholder(graph)
+ self.assertFalse(has_placeholder)
+
+
+def _has_placeholder(graph):
+ ops_types = [op.type for op in graph.get_operations()]
+ return any('Placeholder' in s for s in ops_types)
+
class CheckpointingTests(test.TestCase):
diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD
index 7defc7d..6b80578 100644
--- a/tensorflow/python/keras/optimizer_v2/BUILD
+++ b/tensorflow/python/keras/optimizer_v2/BUILD
@@ -17,6 +17,7 @@
"adagrad.py",
"adam.py",
"adamax.py",
+ "ftrl.py",
"gradient_descent.py",
"nadam.py",
"optimizer_v2.py",
@@ -113,6 +114,25 @@
)
cuda_py_test(
+ name = "ftrl_test",
+ size = "medium",
+ srcs = ["ftrl_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:resources",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ ],
+ shard_count = 4,
+)
+
+cuda_py_test(
name = "gradient_descent_test",
size = "medium",
srcs = ["gradient_descent_test.py"],
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta.py b/tensorflow/python/keras/optimizer_v2/adadelta.py
index 21a3f06..e1d7ecb 100644
--- a/tensorflow/python/keras/optimizer_v2/adadelta.py
+++ b/tensorflow/python/keras/optimizer_v2/adadelta.py
@@ -19,7 +19,6 @@
from __future__ import print_function
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
-from tensorflow.python.ops import math_ops
from tensorflow.python.training import training_ops
@@ -55,7 +54,8 @@
learning_rate=0.001,
rho=0.95,
epsilon=1e-7,
- name='Adadelta'):
+ name='Adadelta',
+ **kwargs):
"""Construct a new Adadelta optimizer.
Adadelta is a more robust extension of Adagrad that adapts learning rates
@@ -73,6 +73,7 @@
to better conditioning the grad update.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "Adadelta".
+ **kwargs: keyword arguments. Allowed to be {`decay`}
@compatibility(eager)
When eager execution is enabled, `learning_rate`, `rho`, and `epsilon` can
@@ -81,8 +82,9 @@
invocations of optimizer functions.
@end_compatibility
"""
- super(Adadelta, self).__init__(name)
+ super(Adadelta, self).__init__(name, **kwargs)
self._set_hyper('learning_rate', learning_rate)
+ self._set_hyper('decay', self._initial_decay)
self._set_hyper('rho', rho)
self._set_hyper('epsilon', epsilon)
@@ -92,28 +94,32 @@
self.add_slot(v, 'accum_var')
def _resource_apply_dense(self, grad, var):
+ var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
accum_grad = self.get_slot(var, 'accum_grad')
accum_var = self.get_slot(var, 'accum_var')
return training_ops.resource_apply_adadelta(
var.handle,
accum_grad.handle,
accum_var.handle,
- math_ops.cast(self._get_hyper('learning_rate'), grad.dtype.base_dtype),
- math_ops.cast(self._get_hyper('rho'), grad.dtype.base_dtype),
- math_ops.cast(self._get_hyper('epsilon'), grad.dtype.base_dtype),
+ lr_t,
+ self._get_hyper('rho', var_dtype),
+ self._get_hyper('epsilon', var_dtype),
grad,
use_locking=self._use_locking)
def _resource_apply_sparse(self, grad, var, indices):
+ var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
accum_grad = self.get_slot(var, 'accum_grad')
accum_var = self.get_slot(var, 'accum_var')
return training_ops.resource_sparse_apply_adadelta(
var.handle,
accum_grad.handle,
accum_var.handle,
- math_ops.cast(self._get_hyper('learning_rate'), grad.dtype.base_dtype),
- math_ops.cast(self._get_hyper('rho'), grad.dtype.base_dtype),
- math_ops.cast(self._get_hyper('epsilon'), grad.dtype.base_dtype),
+ lr_t,
+ self._get_hyper('rho', var_dtype),
+ self._get_hyper('epsilon', var_dtype),
grad,
indices,
use_locking=self._use_locking)
@@ -122,6 +128,7 @@
config = super(Adadelta, self).get_config()
config.update({
'learning_rate': self._serialize_hyperparameter('learning_rate'),
+ 'decay': self._serialize_hyperparameter('decay'),
'rho': self._serialize_hyperparameter('rho'),
'epsilon': self._serialize_hyperparameter('epsilon'),
})
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad.py b/tensorflow/python/keras/optimizer_v2/adagrad.py
index 7d090e8..0896f95 100644
--- a/tensorflow/python/keras/optimizer_v2/adagrad.py
+++ b/tensorflow/python/keras/optimizer_v2/adagrad.py
@@ -21,7 +21,6 @@
from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -55,7 +54,8 @@
learning_rate=0.001,
initial_accumulator_value=0.1,
epsilon=1e-7,
- name='Adagrad'):
+ name='Adagrad',
+ **kwargs):
"""Construct a new Adagrad optimizer.
Args:
@@ -66,6 +66,7 @@
Starting value for the accumulators, must be positive.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "Adagrad".
+ **kwargs: keyword arguments. Allowed to be {`decay`}
Raises:
ValueError: If the `initial_accumulator_value` or `epsilon` is invalid.
@@ -82,8 +83,9 @@
initial_accumulator_value)
if epsilon < 1e-7:
raise ValueError('epsilon must be larger than 1e-7: %s' % epsilon)
- super(Adagrad, self).__init__(name)
+ super(Adagrad, self).__init__(name, **kwargs)
self._set_hyper('learning_rate', learning_rate)
+ self._set_hyper('decay', self._initial_decay)
self._initial_accumulator_value = initial_accumulator_value
self._set_hyper('epsilon', epsilon)
@@ -94,25 +96,16 @@
self._initial_accumulator_value, dtype=dtype)
self.add_slot(var, 'accumulator', init)
- def _init_constant_op(self, v, dtype):
- def init():
- # Use a Tensor instead of initializer if variable does not have
- # static shape.
- init_constant = gen_array_ops.fill(array_ops.shape(v),
- self._initial_accumulator_value)
- return math_ops.cast(init_constant, dtype)
- return init
-
def _resource_apply_dense(self, grad, var):
var_dtype = var.dtype.base_dtype
- learning_rate = math_ops.cast(self._get_hyper('learning_rate'), var_dtype)
- epsilon = math_ops.cast(self._get_hyper('epsilon'), var_dtype)
+ lr_t = self._decayed_lr(var_dtype)
+ epsilon = self._get_hyper('epsilon', var_dtype)
acc = self.get_slot(var, 'accumulator')
acc_t = state_ops.assign_add(
acc, math_ops.square(grad), use_locking=self._use_locking)
var_update = state_ops.assign_sub(
- var, learning_rate * grad / (math_ops.sqrt(acc_t) + epsilon))
+ var, lr_t * grad / (math_ops.sqrt(acc_t) + epsilon))
return var_update
def _resource_apply_sparse(self, grad, var, indices):
@@ -123,21 +116,21 @@
return x.value()
var_dtype = var.dtype.base_dtype
- learning_rate = math_ops.cast(self._get_hyper('learning_rate'), var_dtype)
- epsilon = math_ops.cast(self._get_hyper('epsilon'), var_dtype)
+ lr_t = self._decayed_lr(var_dtype)
+ epsilon = self._get_hyper('epsilon', var_dtype)
acc = self.get_slot(var, 'accumulator')
acc_t = _resource_scatter_add(acc, indices, math_ops.square(grad))
acc_t_slice = array_ops.gather(acc_t, indices)
var_update = _resource_scatter_add(
- var, indices,
- -learning_rate * grad / (math_ops.sqrt(acc_t_slice) + epsilon))
+ var, indices, -lr_t * grad / (math_ops.sqrt(acc_t_slice) + epsilon))
return var_update
def get_config(self):
config = super(Adagrad, self).get_config()
config.update({
'learning_rate': self._serialize_hyperparameter('learning_rate'),
+ 'decay': self._serialize_hyperparameter('decay'),
'initial_accumulator_value': self._initial_accumulator_value,
'epsilon': self._serialize_hyperparameter('epsilon'),
})
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad_test.py b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
index 7d0f55c..5ddeb1a 100644
--- a/tensorflow/python/keras/optimizer_v2/adagrad_test.py
+++ b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
@@ -116,6 +116,50 @@
with context.eager_mode():
self.doTestBasic(use_callable_params=True)
+ def testBasicWithLearningRateDecay(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ learning_rate = 3.0
+ decay = 0.5
+
+ ada_opt = adagrad.Adagrad(learning_rate, decay=decay)
+
+ accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+
+ if not context.executing_eagerly():
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ self.evaluate(variables.global_variables_initializer())
+
+ # Fetch params to validate initial values
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllClose([1.0, 2.0], v0_val)
+ self.assertAllClose([3.0, 4.0], v1_val)
+
+ # Run 3 steps of adagrad
+ for t in range(3):
+ if not context.executing_eagerly():
+ self.evaluate(ada_update)
+ else:
+ ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ lr_np = learning_rate / (1 + decay * t)
+ var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np,
+ grads0_np, lr_np)
+ var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np,
+ grads1_np, lr_np)
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py
index 962680f..a3f1290 100644
--- a/tensorflow/python/keras/optimizer_v2/adam.py
+++ b/tensorflow/python/keras/optimizer_v2/adam.py
@@ -35,9 +35,13 @@
requirement, invariant to diagonal rescaling of gradients, and is well suited
for problems that are large in terms of data/parameters'.
+ Note, amsgrad is currently not supported and the argument can only be False.
+
# References
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
+ For AMSGrad see [Reddi et al., 2-18]
+ (https://openreview.net/pdf?id=ryQu7f-RZ)
"""
def __init__(self,
@@ -45,26 +49,48 @@
beta_1=0.9,
beta_2=0.999,
epsilon=1e-7,
- name='Adam'):
+ amsgrad=False,
+ name='Adam',
+ **kwargs):
r"""Construct a new Adam optimizer.
- Initialization:
+ If amsgrad = False:
+ Initialization:
- $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
- $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
- $$t := 0 \text{(Initialize timestep)}$$
+ $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
+ $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
+ $$t := 0 \text{(Initialize timestep)}$$
- The update rule for `variable` with gradient `g` uses an optimization
- described at the end of section2 of the paper:
+ The update rule for `variable` with gradient `g` uses an optimization
+ described at the end of section2 of the paper:
- $$t := t + 1$$
- $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
+ $$t := t + 1$$
+ $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
- $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
- $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
- $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
+ $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
+ $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
+ $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
- The default value of 1e-8 for epsilon might not be a good default in
+ If amsgrad = True:
+ Initialization:
+
+ $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
+ $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
+ $$v_hat_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
+ $$t := 0 \text{(Initialize timestep)}$$
+
+ The update rule for `variable` with gradient `g` uses an optimization
+ described at the end of section2 of the paper:
+
+ $$t := t + 1$$
+ $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
+
+ $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
+ $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
+ $$v_hat_t := max(v_hat_{t-1}, v_t)
+ $$variable := variable - lr_t * m_t / (\sqrt{v_hat_t} + \epsilon)$$
+
+ The default value of 1e-7 for epsilon might not be a good default in
general. For example, when training an Inception network on ImageNet a
current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
formulation just before Section 2.1 of the Kingma and Ba paper rather than
@@ -89,19 +115,27 @@
epsilon: A small constant for numerical stability. This epsilon is
"epsilon hat" in the Kingma and Ba paper (in the formula just before
Section 2.1), not the epsilon in Algorithm 1 of the paper.
+ amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
+ the paper "On the Convergence of Adam and beyond".
name: Optional name for the operations created when applying gradients.
Defaults to "Adam". @compatibility(eager) When eager execution is
enabled, `learning_rate`, `beta_1`, `beta_2`, and `epsilon` can each be
a callable that takes no arguments and returns the actual value to use.
This can be useful for changing these values across different
invocations of optimizer functions. @end_compatibility
+ **kwargs: keyword arguments. Allowed to be {`decay`}
"""
- super(Adam, self).__init__(name)
+ super(Adam, self).__init__(name, **kwargs)
self._set_hyper('learning_rate', learning_rate)
+ self._set_hyper('decay', self._initial_decay)
self._set_hyper('beta_1', beta_1)
self._set_hyper('beta_2', beta_2)
self._set_hyper('epsilon', epsilon)
+ # TODO(tanzheny): create op for resource_apply_adam_with_amsgrad
+ if amsgrad:
+ raise ValueError('Amsgrad is currently not supported.')
+ self._amsgrad = amsgrad
def _create_slots(self, var_list):
# Create slots for the first and second moments.
@@ -110,12 +144,13 @@
self.add_slot(var, 'v')
def _resource_apply_dense(self, grad, var):
- grad_dtype = grad.dtype.base_dtype
+ var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
m = self.get_slot(var, 'm')
v = self.get_slot(var, 'v')
- local_step = math_ops.cast(self.iterations + 1, grad_dtype)
- beta_1_t = math_ops.cast(self._get_hyper('beta_1'), grad_dtype)
- beta_2_t = math_ops.cast(self._get_hyper('beta_2'), grad_dtype)
+ beta_1_t = self._get_hyper('beta_1', var_dtype)
+ beta_2_t = self._get_hyper('beta_2', var_dtype)
+ local_step = math_ops.cast(self.iterations + 1, var_dtype)
beta_1_power = math_ops.pow(beta_1_t, local_step)
beta_2_power = math_ops.pow(beta_2_t, local_step)
return training_ops.resource_apply_adam(
@@ -124,22 +159,22 @@
v.handle,
beta_1_power,
beta_2_power,
- math_ops.cast(self._get_hyper('learning_rate'), grad_dtype),
+ lr_t,
beta_1_t,
beta_2_t,
- math_ops.cast(self._get_hyper('epsilon'), grad_dtype),
+ self._get_hyper('epsilon', var_dtype),
grad,
use_locking=self._use_locking)
def _resource_apply_sparse(self, grad, var, indices):
var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
+ beta_1_t = self._get_hyper('beta_1', var_dtype)
+ beta_2_t = self._get_hyper('beta_2', var_dtype)
local_step = math_ops.cast(self.iterations + 1, var_dtype)
- beta_1_t = math_ops.cast(self._get_hyper('beta_1'), var_dtype)
- beta_2_t = math_ops.cast(self._get_hyper('beta_2'), var_dtype)
beta_1_power = math_ops.pow(beta_1_t, local_step)
beta_2_power = math_ops.pow(beta_2_t, local_step)
- lr_t = math_ops.cast(self._get_hyper('learning_rate'), var_dtype)
- epsilon_t = math_ops.cast(self._get_hyper('epsilon'), var_dtype)
+ epsilon_t = self._get_hyper('epsilon', var_dtype)
lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))
# m_t = beta1 * m + (1 - beta1) * g_t
@@ -170,8 +205,10 @@
config = super(Adam, self).get_config()
config.update({
'learning_rate': self._serialize_hyperparameter('learning_rate'),
+ 'decay': self._serialize_hyperparameter('decay'),
'beta_1': self._serialize_hyperparameter('beta_1'),
'beta_2': self._serialize_hyperparameter('beta_2'),
'epsilon': self._serialize_hyperparameter('epsilon'),
+ 'amsgrad': self._amsgrad,
})
return config
diff --git a/tensorflow/python/keras/optimizer_v2/adam_test.py b/tensorflow/python/keras/optimizer_v2/adam_test.py
index 46a45af..e2bc6a3 100644
--- a/tensorflow/python/keras/optimizer_v2/adam_test.py
+++ b/tensorflow/python/keras/optimizer_v2/adam_test.py
@@ -38,16 +38,16 @@
t,
m,
v,
- alpha=0.001,
+ lr=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-7):
- alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)
+ lr_t = lr * np.sqrt(1 - beta2**(t + 1)) / (1 - beta1**(t + 1))
m_t = beta1 * m + (1 - beta1) * g_t
v_t = beta2 * v + (1 - beta2) * g_t * g_t
- param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon)
+ param_t = param - lr_t * m_t / (np.sqrt(v_t) + epsilon)
return param_t, m_t, v_t
@@ -90,13 +90,13 @@
self.assertAllClose([1.0, 1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 3.0, 4.0], self.evaluate(var1))
- beta1_power, beta2_power = get_beta_accumulators(opt, dtype)
-
+ beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
# Run 3 steps of Adam
- for t in range(1, 4):
- self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
- self.assertAllCloseAccordingToType(0.999**t,
- self.evaluate(beta2_power))
+ for t in range(3):
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta_1_power))
+ self.assertAllCloseAccordingToType(0.999**(t + 1),
+ self.evaluate(beta_2_power))
update.run()
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
@@ -177,21 +177,21 @@
epsilon = epsilon()
opt = adam.Adam(learning_rate=learning_rate)
- update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ if not context.executing_eagerly():
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
self.evaluate(variables.global_variables_initializer())
# Run 3 steps of Adam
- for t in range(1, 4):
- if not context.executing_eagerly():
- self.evaluate(update)
- elif t > 1:
- opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
-
+ for t in range(3):
beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
self.assertAllCloseAccordingToType(0.9**(t + 1),
self.evaluate(beta_1_power))
self.assertAllCloseAccordingToType(0.999**(t + 1),
self.evaluate(beta_2_power))
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ else:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
@@ -208,6 +208,52 @@
with context.eager_mode():
self.doTestBasic(use_callable_params=True)
+ def testBasicWithLearningRateDecay(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ with self.session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ learning_rate = 0.001
+ beta_1 = 0.9
+ beta_2 = 0.999
+ epsilon = 1e-7
+ decay = 0.5
+
+ opt = adam.Adam(
+ learning_rate=learning_rate,
+ beta_1=beta_1,
+ beta_2=beta_2,
+ epsilon=epsilon,
+ decay=decay)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ self.evaluate(variables.global_variables_initializer())
+ # Run 3 steps of Adam
+ for t in range(3):
+ self.evaluate(update)
+ lr_np = learning_rate / (1 + decay * t)
+
+ var0_np, m0, v0 = adam_update_numpy(
+ var0_np, grads0_np, t, m0, v0, lr=lr_np)
+ var1_np, m1, v1 = adam_update_numpy(
+ var1_np, grads1_np, t, m1, v1, lr=lr_np)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
@@ -230,13 +276,13 @@
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
- beta1_power, beta2_power = get_beta_accumulators(opt, dtype)
-
+ beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
# Run 3 steps of Adam
- for t in range(1, 4):
- self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
- self.assertAllCloseAccordingToType(0.999**t,
- self.evaluate(beta2_power))
+ for t in range(3):
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta_1_power))
+ self.assertAllCloseAccordingToType(0.999**(t + 1),
+ self.evaluate(beta_2_power))
update.run()
var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
@@ -265,17 +311,18 @@
update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- beta1_power, beta2_power = get_beta_accumulators(opt, dtype)
+ beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Run 3 steps of intertwined Adam1 and Adam2.
- for t in range(1, 4):
- self.assertAllCloseAccordingToType(0.9**t, self.evaluate(beta1_power))
- self.assertAllCloseAccordingToType(0.999**t,
- self.evaluate(beta2_power))
+ for t in range(3):
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta_1_power))
+ self.assertAllCloseAccordingToType(0.999**(t + 1),
+ self.evaluate(beta_2_power))
if t % 2 == 0:
update1.run()
else:
@@ -296,7 +343,12 @@
opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
# There should be iteration, hyper variables, and two unique slot
# variables for v1 and v2 respectively.
- self.assertEqual(9, len(set(opt.variables())))
+ self.assertEqual(10, len(set(opt.variables())))
+
+ def testAmsgradWithError(self):
+ with self.assertRaisesRegexp(ValueError,
+ "Amsgrad is currently not supported"):
+ adam.Adam(learning_rate=1., beta_1=0.9, beta_2=0.99, amsgrad=True)
if __name__ == "__main__":
diff --git a/tensorflow/python/keras/optimizer_v2/adamax.py b/tensorflow/python/keras/optimizer_v2/adamax.py
index 6712427..ddd7858 100644
--- a/tensorflow/python/keras/optimizer_v2/adamax.py
+++ b/tensorflow/python/keras/optimizer_v2/adamax.py
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
-"""AdaMax for TensorFlow."""
+"""Adamax for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -27,8 +27,8 @@
from tensorflow.python.training import training_ops
-class AdaMax(adam.Adam):
- """Optimizer that implements the AdaMax algorithm.
+class Adamax(adam.Adam):
+ """Optimizer that implements the Adamax algorithm.
It is a variant of Adam based on the infinity norm.
Default parameters follow those provided in the paper.
@@ -44,8 +44,9 @@
beta_1=0.9,
beta_2=0.999,
epsilon=1e-7,
- name='AdaMax'):
- """Construct a new AdaMax optimizer.
+ name='Adamax',
+ **kwargs):
+ """Construct a new Adamax optimizer.
Initialization:
@@ -86,41 +87,50 @@
rate for the exponentially weighted infinity norm.
epsilon: A small constant for numerical stability.
name: Optional name for the operations created when applying gradients.
- Defaults to "AdaMax".
+ Defaults to "Adamax".
+ **kwargs: keyword arguments. Allowed to be {`decay`}
"""
# pylint: disable=useless-super-delegation
- super(AdaMax, self).__init__(learning_rate, beta_1, beta_2, epsilon, name)
+ super(Adamax, self).__init__(
+ learning_rate=learning_rate,
+ beta_1=beta_1,
+ beta_2=beta_2,
+ epsilon=epsilon,
+ amsgrad=False,
+ name=name,
+ **kwargs)
# pylint: enable=useless-super-delegation
def _resource_apply_dense(self, grad, var):
- grad_dtype = grad.dtype.base_dtype
+ var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
m = self.get_slot(var, 'm')
v = self.get_slot(var, 'v')
- local_step = math_ops.cast(self.iterations + 1, grad_dtype)
- beta_1_t = math_ops.cast(self._get_hyper('beta_1'), grad_dtype)
- beta_2_t = math_ops.cast(self._get_hyper('beta_2'), grad_dtype)
+ beta_1_t = self._get_hyper('beta_1', var_dtype)
+ beta_2_t = self._get_hyper('beta_2', var_dtype)
+ local_step = math_ops.cast(self.iterations + 1, var_dtype)
beta_1_power = math_ops.pow(beta_1_t, local_step)
return training_ops.resource_apply_ada_max(
var.handle,
m.handle,
v.handle,
beta_1_power,
- math_ops.cast(self._get_hyper('learning_rate'), grad_dtype),
+ lr_t,
beta_1_t,
beta_2_t,
- math_ops.cast(self._get_hyper('epsilon'), grad_dtype),
+ self._get_hyper('epsilon', var_dtype),
grad,
use_locking=self._use_locking)
def _resource_apply_sparse(self, grad, var, indices):
- grad_dtype = grad.dtype.base_dtype
+ var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
- local_step = math_ops.cast(self.iterations + 1, grad_dtype)
- beta_1_t = math_ops.cast(self._get_hyper('beta_1'), grad_dtype)
- beta_2_t = math_ops.cast(self._get_hyper('beta_2'), grad_dtype)
+ beta_1_t = self._get_hyper('beta_1', var_dtype)
+ beta_2_t = self._get_hyper('beta_2', var_dtype)
+ local_step = math_ops.cast(self.iterations + 1, var_dtype)
beta_1_power = math_ops.pow(beta_1_t, local_step)
- lr_t = math_ops.cast(self._get_hyper('learning_rate'), grad_dtype)
- epsilon_t = math_ops.cast(self._get_hyper('epsilon'), grad_dtype)
+ epsilon_t = self._get_hyper('epsilon', var_dtype)
# m_t = beta1 * m + (1 - beta1) * g_t
m = self.get_slot(var, 'm')
diff --git a/tensorflow/python/keras/optimizer_v2/adamax_test.py b/tensorflow/python/keras/optimizer_v2/adamax_test.py
index 23eb7184..aa215b0 100644
--- a/tensorflow/python/keras/optimizer_v2/adamax_test.py
+++ b/tensorflow/python/keras/optimizer_v2/adamax_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for AdaMax."""
+"""Tests for Adamax."""
from __future__ import absolute_import
from __future__ import division
@@ -44,7 +44,7 @@
epsilon=1e-8):
m_t = beta1 * m + (1 - beta1) * g_t
v_t = np.maximum(beta2 * v, np.abs(g_t))
- param_t = param - (alpha / (1 - beta1**t)) * (m_t / (v_t + epsilon))
+ param_t = param - (alpha / (1 - beta1**(t + 1))) * (m_t / (v_t + epsilon))
return param_t, m_t, v_t
@@ -61,8 +61,8 @@
m_t, v_t, param_t = np.copy(m), np.copy(v), np.copy(param)
m_t_slice = beta1 * m[indices] + (1 - beta1) * g_t
v_t_slice = np.maximum(beta2 * v[indices], np.abs(g_t))
- param_t_slice = param[indices] - ((alpha / (1 - beta1**t)) *
- (m_t_slice / (v_t_slice + epsilon)))
+ param_t_slice = param[indices] - (
+ (alpha / (1 - beta1**(t + 1))) * (m_t_slice / (v_t_slice + epsilon)))
m_t[indices] = m_t_slice
v_t[indices] = v_t_slice
param_t[indices] = param_t_slice
@@ -76,7 +76,7 @@
return beta_1_power
-class AdaMaxOptimizerTest(test.TestCase):
+class AdamaxOptimizerTest(test.TestCase):
def doTestSparse(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
@@ -100,7 +100,7 @@
grads1 = ops.IndexedSlices(
constant_op.constant(grads1_np),
constant_op.constant(grads1_np_indices), constant_op.constant([3]))
- opt = adamax.AdaMax()
+ opt = adamax.Adamax()
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
@@ -110,9 +110,9 @@
beta1_power = get_beta_accumulators(opt, dtype)
- # Run 3 steps of AdaMax
- for t in range(1, 4):
- self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ # Run 3 steps of Adamax
+ for t in range(3):
+ self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
update.run()
var0_np, m0, v0 = adamax_sparse_update_numpy(
@@ -135,7 +135,7 @@
var = variables.Variable([[1.0], [2.0]])
indices = constant_op.constant([0, 1], dtype=index_dtype)
gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
- optimizer = adamax.AdaMax(3.0)
+ optimizer = adamax.Adamax(3.0)
minimize_op = optimizer.minimize(gathered_sum, var_list=[var])
variables.global_variables_initializer().run()
minimize_op.run()
@@ -157,9 +157,9 @@
[0.2], shape=[1, 1], dtype=dtype),
constant_op.constant([1]),
constant_op.constant([2, 1]))
- repeated_update = adamax.AdaMax().apply_gradients(
+ repeated_update = adamax.Adamax().apply_gradients(
[(grad_repeated_index, repeated_index_update_var)])
- aggregated_update = adamax.AdaMax().apply_gradients(
+ aggregated_update = adamax.Adamax().apply_gradients(
[(grad_aggregated, aggregated_update_var)])
variables.global_variables_initializer().run()
self.assertAllClose(aggregated_update_var.eval(),
@@ -189,8 +189,9 @@
grads0 = constant_op.constant(grads0_np)
grads1 = constant_op.constant(grads1_np)
- opt = adamax.AdaMax()
- update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt = adamax.Adamax()
+ if not context.executing_eagerly():
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
@@ -198,21 +199,74 @@
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
- # Run 3 steps of AdaMax
- for t in range(1, 4):
- if not context.executing_eagerly():
- self.evaluate(update)
- elif t > 1:
- opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
-
+ # Run 3 steps of Adamax
+ for t in range(3):
beta_1_power = get_beta_accumulators(opt, dtype)
self.assertAllCloseAccordingToType(0.9**(t + 1),
self.evaluate(beta_1_power))
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ else:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1)
# Validate updated params
+ self.assertAllCloseAccordingToType(
+ var0_np, self.evaluate(var0), rtol=1e-2)
+ self.assertAllCloseAccordingToType(
+ var1_np, self.evaluate(var1), rtol=1e-2)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testBasicWithLearningRateDecay(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ with self.session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ learning_rate = 0.001
+ decay = 0.002
+ opt = adamax.Adamax(learning_rate=learning_rate, decay=decay)
+ if not context.executing_eagerly():
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ # Run 3 steps of Adamax
+ for t in range(3):
+ beta_1_power = get_beta_accumulators(opt, dtype)
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta_1_power))
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ else:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ lr = learning_rate / (1 + decay * t)
+
+ var0_np, m0, v0 = adamax_update_numpy(
+ var0_np, grads0_np, t, m0, v0, alpha=lr)
+ var1_np, m1, v1 = adamax_update_numpy(
+ var1_np, grads1_np, t, m1, v1, alpha=lr)
+
+ # Validate updated params
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0),
rtol=1e-2)
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1),
@@ -232,7 +286,7 @@
var1 = variables.Variable(var1_np)
grads0 = constant_op.constant(grads0_np)
grads1 = constant_op.constant(grads1_np)
- opt = adamax.AdaMax(constant_op.constant(0.001))
+ opt = adamax.Adamax(constant_op.constant(0.001))
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
@@ -242,9 +296,9 @@
beta1_power = get_beta_accumulators(opt, dtype)
- # Run 3 steps of AdaMax
- for t in range(1, 4):
- self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ # Run 3 steps of Adamax
+ for t in range(3):
+ self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
update.run()
var0_np, m0, v0 = adamax_update_numpy(var0_np, grads0_np, t, m0, v0)
@@ -268,7 +322,7 @@
var1 = variables.Variable(var1_np)
grads0 = constant_op.constant(grads0_np)
grads1 = constant_op.constant(grads1_np)
- opt = adamax.AdaMax()
+ opt = adamax.Adamax()
update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
@@ -279,9 +333,9 @@
self.assertAllClose([1.0, 2.0], var0.eval())
self.assertAllClose([3.0, 4.0], var1.eval())
- # Run 3 steps of intertwined AdaMax1 and AdaMax2.
- for t in range(1, 4):
- self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ # Run 3 steps of intertwined Adamax1 and Adamax2.
+ for t in range(3):
+ self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
if t % 2 == 0:
update1.run()
else:
@@ -298,11 +352,11 @@
with context.eager_mode():
v1 = resource_variable_ops.ResourceVariable(1.)
v2 = resource_variable_ops.ResourceVariable(1.)
- opt = adamax.AdaMax(1.)
+ opt = adamax.Adamax(1.)
opt.minimize(lambda: v1 + v2, var_list=[v1, v2])
# There should be iteration, hyper variables, and two unique slot
# variables for v1 and v2 respectively.
- self.assertEqual(9, len(set(opt.variables())))
+ self.assertEqual(10, len(set(opt.variables())))
if __name__ == "__main__":
diff --git a/tensorflow/python/keras/optimizer_v2/ftrl.py b/tensorflow/python/keras/optimizer_v2/ftrl.py
new file mode 100644
index 0000000..e278e35
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/ftrl.py
@@ -0,0 +1,210 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ftrl-proximal for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import training_ops
+
+
+class Ftrl(optimizer_v2.OptimizerV2):
+ """Optimizer that implements the FTRL algorithm.
+
+ See this [paper](
+ https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
+ This version has support for both online L2 (the L2 penalty given in the paper
+ above) and shrinkage-type L2 (which is the addition of an L2 penalty to the
+ loss function).
+ """
+
+ def __init__(self,
+ learning_rate,
+ learning_rate_power=-0.5,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0,
+ name='Ftrl',
+ l2_shrinkage_regularization_strength=0.0,
+ **kwargs):
+ r"""Construct a new FTRL optimizer.
+
+ Args:
+ learning_rate: A float value or a constant float `Tensor`.
+ learning_rate_power: A float value, must be less or equal to zero.
+ Controls how the learning rate decreases during training. Use zero for
+ a fixed learning rate.
+ initial_accumulator_value: The starting value for accumulators.
+ Only zero or positive values are allowed.
+ l1_regularization_strength: A float value, must be greater than or
+ equal to zero.
+ l2_regularization_strength: A float value, must be greater than or
+ equal to zero.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "Ftrl".
+ l2_shrinkage_regularization_strength: A float value, must be greater than
+ or equal to zero. This differs from L2 above in that the L2 above is a
+ stabilization penalty, whereas this L2 shrinkage is a magnitude penalty.
+ The FTRL formulation can be written as:
+ w_{t+1} = argmin_w(\hat{g}_{1:t}w + L1*||w||_1 + L2*||w||_2^2), where
+ \hat{g} = g + (2*L2_shrinkage*w), and g is the gradient of the loss
+ function w.r.t. the weights w.
+ Specifically, in the absence of L1 regularization, it is equivalent to
+ the following update rule:
+ w_{t+1} = w_t - lr_t / (1 + 2*L2*lr_t) * g_t -
+ 2*L2_shrinkage*lr_t / (1 + 2*L2*lr_t) * w_t
+ where lr_t is the learning rate at t.
+ When input is sparse shrinkage will only happen on the active weights.\
+ **kwargs: keyword arguments. Allowed to be {`decay`}
+
+ Raises:
+ ValueError: If one of the arguments is invalid.
+
+ References
+ See [paper]
+ (https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf)
+ """
+ super(Ftrl, self).__init__(name, **kwargs)
+
+ if initial_accumulator_value < 0.0:
+ raise ValueError(
+ 'initial_accumulator_value %f needs to be positive or zero' %
+ initial_accumulator_value)
+ if learning_rate_power > 0.0:
+ raise ValueError('learning_rate_power %f needs to be negative or zero' %
+ learning_rate_power)
+ if l1_regularization_strength < 0.0:
+ raise ValueError(
+ 'l1_regularization_strength %f needs to be positive or zero' %
+ l1_regularization_strength)
+ if l2_regularization_strength < 0.0:
+ raise ValueError(
+ 'l2_regularization_strength %f needs to be positive or zero' %
+ l2_regularization_strength)
+ if l2_shrinkage_regularization_strength < 0.0:
+ raise ValueError(
+ 'l2_shrinkage_regularization_strength %f needs to be positive'
+ ' or zero' % l2_shrinkage_regularization_strength)
+
+ self._set_hyper('learning_rate', learning_rate)
+ self._set_hyper('decay', self._initial_decay)
+ self._set_hyper('learning_rate_power', learning_rate_power)
+ self._set_hyper('l1_regularization_strength', l1_regularization_strength)
+ self._set_hyper('l2_regularization_strength', l2_regularization_strength)
+ self._initial_accumulator_value = initial_accumulator_value
+ self._l2_shrinkage_regularization_strength = (
+ l2_shrinkage_regularization_strength)
+
+ def _create_slots(self, var_list):
+ # Create the "accum" and "linear" slots.
+ for var in var_list:
+ dtype = var.dtype.base_dtype
+ init = init_ops.constant_initializer(
+ self._initial_accumulator_value, dtype=dtype)
+ self.add_slot(var, 'accumulator', init)
+ self.add_slot(var, 'linear')
+
+ def _resource_apply_dense(self, grad, var):
+ var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
+ learning_rate_power = self._get_hyper('learning_rate_power', var_dtype)
+ l1_regularization_strength = self._get_hyper('l1_regularization_strength',
+ var_dtype)
+ l2_regularization_strength = self._get_hyper('l2_regularization_strength',
+ var_dtype)
+ accum = self.get_slot(var, 'accumulator')
+ linear = self.get_slot(var, 'linear')
+ if self._l2_shrinkage_regularization_strength <= 0.0:
+ return training_ops.resource_apply_ftrl(
+ var.handle,
+ accum.handle,
+ linear.handle,
+ grad,
+ lr_t,
+ l1_regularization_strength,
+ l2_regularization_strength,
+ learning_rate_power,
+ use_locking=self._use_locking)
+ else:
+ return training_ops.resource_apply_ftrl_v2(
+ var.handle,
+ accum.handle,
+ linear.handle,
+ grad,
+ lr_t,
+ l1_regularization_strength,
+ l2_regularization_strength,
+ math_ops.cast(self._l2_shrinkage_regularization_strength, var_dtype),
+ learning_rate_power,
+ use_locking=self._use_locking)
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
+ learning_rate_power = self._get_hyper('learning_rate_power', var_dtype)
+ l1_regularization_strength = self._get_hyper('l1_regularization_strength',
+ var_dtype)
+ l2_regularization_strength = self._get_hyper('l2_regularization_strength',
+ var_dtype)
+ accum = self.get_slot(var, 'accumulator')
+ linear = self.get_slot(var, 'linear')
+ if self._l2_shrinkage_regularization_strength <= 0.0:
+ return training_ops.resource_sparse_apply_ftrl(
+ var.handle,
+ accum.handle,
+ linear.handle,
+ grad,
+ indices,
+ lr_t,
+ l1_regularization_strength,
+ l2_regularization_strength,
+ learning_rate_power,
+ use_locking=self._use_locking)
+ else:
+ return training_ops.resource_sparse_apply_ftrl_v2(
+ var.handle,
+ accum.handle,
+ linear.handle,
+ grad,
+ indices,
+ lr_t,
+ l1_regularization_strength,
+ l2_regularization_strength,
+ math_ops.cast(self._l2_shrinkage_regularization_strength, var_dtype),
+ learning_rate_power,
+ use_locking=self._use_locking)
+
+ def get_config(self):
+ config = super(Ftrl, self).get_config()
+ config.update({
+ 'learning_rate':
+ self._serialize_hyperparameter('learning_rate'),
+ 'decay':
+ self._serialize_hyperparameter('decay'),
+ 'initial_accumulator_value':
+ self._initial_accumulator_value,
+ 'learning_rate_power':
+ self._serialize_hyperparameter('learning_rate_power'),
+ 'l1_regularization_strength':
+ self._serializer_hyperparameter('l1_regularization_strength'),
+ 'l2_regularization_strength':
+ self._serializer_hyperparameter('l2_regularization_strength'),
+ 'l2_shrinkage_regularization_strength':
+ self._l2_shrinkage_regularization_strength,
+ })
+ return config
diff --git a/tensorflow/python/keras/optimizer_v2/ftrl_test.py b/tensorflow/python/keras/optimizer_v2/ftrl_test.py
new file mode 100644
index 0000000..ca8c33d
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/ftrl_test.py
@@ -0,0 +1,426 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for Ftrl operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.optimizer_v2 import ftrl
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import adagrad
+from tensorflow.python.training import gradient_descent
+
+
+class FtrlOptimizerTest(test.TestCase):
+
+ def doTestFtrlwithoutRegularization(self, use_resource=False):
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.cached_session() as sess:
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
+ else:
+ var0 = variables.Variable([0.0, 0.0], dtype=dtype)
+ var1 = variables.Variable([0.0, 0.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+ opt = ftrl.Ftrl(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllClose([0.0, 0.0], v0_val)
+ self.assertAllClose([0.0, 0.0], v1_val)
+
+ # Run 3 steps FTRL
+ for _ in range(3):
+ update.run()
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType(
+ np.array([-2.60260963, -4.29698515]), v0_val)
+ self.assertAllCloseAccordingToType(
+ np.array([-0.28432083, -0.56694895]), v1_val)
+
+ def testFtrlWithoutRegularization(self):
+ self.doTestFtrlwithoutRegularization(use_resource=False)
+
+ def testResourceFtrlWithoutRegularization(self):
+ self.doTestFtrlwithoutRegularization(use_resource=True)
+
+ def testFtrlwithoutRegularization2(self):
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.cached_session() as sess:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = ftrl.Ftrl(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+ self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+ # Run 3 steps FTRL
+ for _ in range(3):
+ update.run()
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType(
+ np.array([-2.55607247, -3.98729396]), v0_val)
+ self.assertAllCloseAccordingToType(
+ np.array([-0.28232238, -0.56096673]), v1_val)
+
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = ftrl.Ftrl(1.0).minimize(loss, var_list=[var0])
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([[0, 1]],
+ self.evaluate(var0),
+ atol=0.01)
+
+ def testFtrlWithL1(self):
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.cached_session() as sess:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = ftrl.Ftrl(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=0.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+ self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update.run()
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType(
+ np.array([-7.66718769, -10.91273689]), v0_val)
+ self.assertAllCloseAccordingToType(
+ np.array([-0.93460727, -1.86147261]), v1_val)
+
+ def testFtrlWithL1_L2(self):
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.cached_session() as sess:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = ftrl.Ftrl(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+ self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update.run()
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType(
+ np.array([-0.24059935, -0.46829352]), v0_val)
+ self.assertAllCloseAccordingToType(
+ np.array([-0.02406147, -0.04830509]), v1_val)
+
+ def testFtrlWithL1_L2_L2Shrinkage(self):
+ """Test the new FTRL op with support for l2 shrinkage.
+
+ The addition of this parameter which places a constant pressure on weights
+ towards the origin causes the gradient descent trajectory to differ. The
+ weights will tend to have smaller magnitudes with this parameter set.
+ """
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.cached_session() as sess:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([4.0, 3.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ opt = ftrl.Ftrl(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0,
+ l2_shrinkage_regularization_strength=0.1)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+ self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update.run()
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType(
+ np.array([-0.22578995, -0.44345796]), v0_val)
+ self.assertAllCloseAccordingToType(
+ np.array([-0.14378493, -0.13229476]), v1_val)
+
+ def testFtrlWithL1_L2_L2ShrinkageSparse(self):
+ """Tests the new FTRL op with support for l2 shrinkage on sparse grads."""
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.cached_session() as sess:
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[4.0], [3.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant([0.02], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
+
+ opt = ftrl.Ftrl(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0,
+ l2_shrinkage_regularization_strength=0.1)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType([[1.0], [2.0]], v0_val)
+ self.assertAllCloseAccordingToType([[4.0], [3.0]], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update.run()
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType([[-0.22578995], [2.]], v0_val)
+ self.assertAllCloseAccordingToType([[4.], [-0.13229476]], v1_val)
+
+ def testFtrlWithL2ShrinkageDoesNotChangeLrSchedule(self):
+ """Verifies that l2 shrinkage in FTRL does not change lr schedule."""
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.cached_session() as sess:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([1.0, 2.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.1, 0.2], dtype=dtype)
+
+ opt0 = ftrl.Ftrl(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0,
+ l2_shrinkage_regularization_strength=0.1)
+ opt1 = ftrl.Ftrl(
+ 3.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.001,
+ l2_regularization_strength=2.0)
+ update0 = opt0.apply_gradients([(grads0, var0)])
+ update1 = opt1.apply_gradients([(grads1, var1)])
+ variables.global_variables_initializer().run()
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
+ self.assertAllCloseAccordingToType([1.0, 2.0], v1_val)
+
+ # Run 10 steps FTRL
+ for _ in range(10):
+ update0.run()
+ update1.run()
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ # var0 is experiencing L2 shrinkage so it should be smaller than var1
+ # in magnitude.
+ self.assertTrue((v0_val**2 < v1_val**2).all())
+ accum0 = sess.run(opt0.get_slot(var0, "accumulator"))
+ accum1 = sess.run(opt1.get_slot(var1, "accumulator"))
+ # L2 shrinkage should not change how we update grad accumulator.
+ self.assertAllCloseAccordingToType(accum0, accum1)
+
+ def applyOptimizer(self, opt, dtype, steps=5, is_sparse=False):
+ if is_sparse:
+ var0 = variables.Variable([[0.0], [0.0]], dtype=dtype)
+ var1 = variables.Variable([[0.0], [0.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant([0.02], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
+ else:
+ var0 = variables.Variable([0.0, 0.0], dtype=dtype)
+ var1 = variables.Variable([0.0, 0.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.2], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.02], dtype=dtype)
+
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ sess = ops.get_default_session()
+ v0_val, v1_val = self.evaluate([var0, var1])
+ if is_sparse:
+ self.assertAllCloseAccordingToType([[0.0], [0.0]], v0_val)
+ self.assertAllCloseAccordingToType([[0.0], [0.0]], v1_val)
+ else:
+ self.assertAllCloseAccordingToType([0.0, 0.0], v0_val)
+ self.assertAllCloseAccordingToType([0.0, 0.0], v1_val)
+
+ # Run Ftrl for a few steps
+ for _ in range(steps):
+ update.run()
+
+ v0_val, v1_val = self.evaluate([var0, var1])
+ return v0_val, v1_val
+
+ # When variables are initialized with Zero, FTRL-Proximal has two properties:
+ # 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical
+ # with GradientDescent.
+ # 2. Without L1&L2 but with adaptive learning rate, FTRL-Proximal is identical
+ # with Adagrad.
+ # So, basing on these two properties, we test if our implementation of
+ # FTRL-Proximal performs same updates as Adagrad or GradientDescent.
+ def testEquivAdagradwithoutRegularization(self):
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.cached_session():
+ val0, val1 = self.applyOptimizer(
+ ftrl.Ftrl(
+ 3.0,
+ # Adagrad learning rate
+ learning_rate_power=-0.5,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0),
+ dtype)
+
+ with self.cached_session():
+ val2, val3 = self.applyOptimizer(
+ adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1), dtype)
+
+ self.assertAllCloseAccordingToType(val0, val2)
+ self.assertAllCloseAccordingToType(val1, val3)
+
+ def testEquivSparseAdagradwithoutRegularization(self):
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.cached_session():
+ val0, val1 = self.applyOptimizer(
+ ftrl.Ftrl(
+ 3.0,
+ # Adagrad learning rate
+ learning_rate_power=-0.5,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0),
+ dtype,
+ is_sparse=True)
+
+ with self.cached_session():
+ val2, val3 = self.applyOptimizer(
+ adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1),
+ dtype,
+ is_sparse=True)
+
+ self.assertAllCloseAccordingToType(val0, val2)
+ self.assertAllCloseAccordingToType(val1, val3)
+
+ def testEquivSparseGradientDescentwithoutRegularization(self):
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.cached_session():
+ val0, val1 = self.applyOptimizer(
+ ftrl.Ftrl(
+ 3.0,
+ # Fixed learning rate
+ learning_rate_power=-0.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0),
+ dtype,
+ is_sparse=True)
+
+ with self.cached_session():
+ val2, val3 = self.applyOptimizer(
+ gradient_descent.GradientDescentOptimizer(3.0),
+ dtype,
+ is_sparse=True)
+
+ self.assertAllCloseAccordingToType(val0, val2)
+ self.assertAllCloseAccordingToType(val1, val3)
+
+ def testEquivGradientDescentwithoutRegularization(self):
+ for dtype in [dtypes.half, dtypes.float32]:
+ with self.cached_session():
+ val0, val1 = self.applyOptimizer(
+ ftrl.Ftrl(
+ 3.0,
+ # Fixed learning rate
+ learning_rate_power=-0.0,
+ initial_accumulator_value=0.1,
+ l1_regularization_strength=0.0,
+ l2_regularization_strength=0.0),
+ dtype)
+
+ with self.cached_session():
+ val2, val3 = self.applyOptimizer(
+ gradient_descent.GradientDescentOptimizer(3.0), dtype)
+
+ self.assertAllCloseAccordingToType(val0, val2)
+ self.assertAllCloseAccordingToType(val1, val3)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent.py b/tensorflow/python/keras/optimizer_v2/gradient_descent.py
index 90106c9..03e4515 100644
--- a/tensorflow/python/keras/optimizer_v2/gradient_descent.py
+++ b/tensorflow/python/keras/optimizer_v2/gradient_descent.py
@@ -19,7 +19,6 @@
from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import training_ops
@@ -62,7 +61,8 @@
learning_rate=0.001,
momentum=0.0,
nesterov=False,
- name="SGD"):
+ name="SGD",
+ **kwargs):
"""Construct a new Stochastic Gradient Descent or Momentum optimizer.
Arguments:
@@ -72,9 +72,11 @@
nesterov: boolean. Whether to apply Nesterov momentum.
name: Optional name prefix for the operations created when applying
gradients. Defaults to 'SGD'.
+ **kwargs: keyword arguments. Allowed to be {`decay`}
"""
- super(SGD, self).__init__(name)
+ super(SGD, self).__init__(name, **kwargs)
self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("decay", self._initial_decay)
self._momentum = False
if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
@@ -91,44 +93,44 @@
self.add_slot(var, "momentum")
def _resource_apply_dense(self, grad, var):
- learning_rate = self._get_hyper("learning_rate")
+ var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
if self._momentum:
momentum_var = self.get_slot(var, "momentum")
return training_ops.resource_apply_momentum(
var.handle,
momentum_var.handle,
- math_ops.cast(learning_rate, grad.dtype.base_dtype),
+ lr_t,
grad,
- math_ops.cast(self._get_hyper("momentum"), grad.dtype.base_dtype),
+ self._get_hyper("momentum", var_dtype),
use_locking=self._use_locking,
use_nesterov=self._nesterov)
else:
return training_ops.resource_apply_gradient_descent(
- var.handle,
- math_ops.cast(learning_rate, grad.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
+ var.handle, lr_t, grad, use_locking=self._use_locking)
def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
if self._momentum:
return super(SGD, self)._resource_apply_sparse_duplicate_indices(
grad, var, indices)
else:
- return resource_variable_ops.resource_scatter_add(
- var.handle, indices, -grad * math_ops.cast(
- self._get_hyper("learning_rate"), grad.dtype.base_dtype))
+ var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
+ return resource_variable_ops.resource_scatter_add(var.handle, indices,
+ -grad * lr_t)
def _resource_apply_sparse(self, grad, var, indices):
# This method is only needed for momentum optimization.
- learning_rate = self._get_hyper("learning_rate")
+ var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
momentum_var = self.get_slot(var, "momentum")
return training_ops.resource_sparse_apply_momentum(
var.handle,
momentum_var.handle,
- math_ops.cast(learning_rate, grad.dtype.base_dtype),
+ lr_t,
grad,
indices,
- math_ops.cast(self._get_hyper("momentum"), grad.dtype.base_dtype),
+ self._get_hyper("momentum", var_dtype),
use_locking=self._use_locking,
use_nesterov=self._nesterov)
@@ -136,6 +138,7 @@
config = super(SGD, self).get_config()
config.update({
"learning_rate": self._serialize_hyperparameter("learning_rate"),
+ "decay": self._serialize_hyperparameter("decay"),
"momentum": self._serialize_hyperparameter("momentum"),
"nesterov": self._nesterov,
})
diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py
index fa7cca1..348d272 100644
--- a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py
+++ b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py
@@ -47,7 +47,6 @@
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
sgd = gradient_descent.SGD(3.0)
- # self.assertFalse(sgd._initial_decay)
sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
self.evaluate(variables.global_variables_initializer())
# Run 1 step of sgd
@@ -59,6 +58,43 @@
self.evaluate(var1))
@test_util.run_in_graph_and_eager_modes
+ def testBasicWithLearningRateDecay(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ learning_rate = 3.0
+ decay = 0.5
+ sgd = gradient_descent.SGD(learning_rate=learning_rate, decay=decay)
+ if not context.executing_eagerly():
+ sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ self.evaluate(variables.global_variables_initializer())
+ # Run 2 steps of sgd
+ if not context.executing_eagerly():
+ self.evaluate(sgd_op)
+ else:
+ sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ self.evaluate(var0))
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ self.evaluate(var1))
+
+ if not context.executing_eagerly():
+ self.evaluate(sgd_op)
+ else:
+ sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [1.0 - 3.0 * 0.1 - 2.0 * 0.1, 2.0 - 3.0 * 0.1 - 2.0 * 0.1],
+ self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ [3.0 - 3.0 * 0.01 - 2.0 * 0.01, 4.0 - 3.0 * 0.01 - 2.0 * 0.01],
+ self.evaluate(var1))
+
+ @test_util.run_in_graph_and_eager_modes
def testBasicCallableParams(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
@@ -170,6 +206,36 @@
self.assertAllCloseAccordingToType([[3.0], [4.0 - 3.0 * 0.01]],
self.evaluate(var1))
+ def testSparseBasicWithLearningRateDecay(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant([0.01], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
+ sgd_op = gradient_descent.SGD(
+ 3.0, decay=0.5).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ self.evaluate(variables.global_variables_initializer())
+ # Run 2 steps of sgd
+ self.evaluate(sgd_op)
+ # Validate updated params
+ self.assertAllCloseAccordingToType([[1.0 - 3.0 * 0.1], [2.0]],
+ self.evaluate(var0))
+ self.assertAllCloseAccordingToType([[3.0], [4.0 - 3.0 * 0.01]],
+ self.evaluate(var1))
+
+ self.evaluate(sgd_op)
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[1.0 - 3.0 * 0.1 - 2.0 * 0.1], [2.0]], self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ [[3.0], [4.0 - 3.0 * 0.01 - 2.0 * 0.01]], self.evaluate(var1))
+
def testCapturingInDefunWhileExecutingEagerly(self):
with context.eager_mode():
optimizer = gradient_descent.SGD(1.0)
diff --git a/tensorflow/python/keras/optimizer_v2/nadam.py b/tensorflow/python/keras/optimizer_v2/nadam.py
index 4be421a..00b095e 100644
--- a/tensorflow/python/keras/optimizer_v2/nadam.py
+++ b/tensorflow/python/keras/optimizer_v2/nadam.py
@@ -53,13 +53,46 @@
See [Dozat, T., 2015](http://cs229.stanford.edu/proj2015/054_report.pdf).
"""
+ def __init__(self,
+ learning_rate=0.001,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-7,
+ name='Nadam',
+ **kwargs):
+ """Construct a new Nadam optimizer.
+
+ Args:
+ learning_rate: A Tensor or a floating point value. The learning rate.
+ beta_1: A float value or a constant float tensor. The exponential decay
+ rate for the 1st moment estimates.
+ beta_2: A float value or a constant float tensor. The exponential decay
+ rate for the exponentially weighted infinity norm.
+ epsilon: A small constant for numerical stability.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "Adamax".
+ **kwargs: keyword arguments. Allowed to be {`decay`}
+ """
+
+ # pylint: disable=useless-super-delegation
+ super(Nadam, self).__init__(
+ learning_rate=learning_rate,
+ beta_1=beta_1,
+ beta_2=beta_2,
+ epsilon=epsilon,
+ amsgrad=False,
+ name=name,
+ **kwargs)
+ # pylint: enable=useless-super-delegation
+
def _resource_apply_dense(self, grad, var):
- grad_dtype = grad.dtype.base_dtype
+ var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
m = self.get_slot(var, 'm')
v = self.get_slot(var, 'v')
- local_step = math_ops.cast(self.iterations + 1, grad_dtype)
- beta_1_t = math_ops.cast(self._get_hyper('beta_1'), grad_dtype)
- beta_2_t = math_ops.cast(self._get_hyper('beta_2'), grad_dtype)
+ beta_1_t = self._get_hyper('beta_1', var_dtype)
+ beta_2_t = self._get_hyper('beta_2', var_dtype)
+ local_step = math_ops.cast(self.iterations + 1, var_dtype)
beta_1_power = math_ops.pow(beta_1_t, local_step)
beta_2_power = math_ops.pow(beta_2_t, local_step)
return training_ops.resource_apply_adam(
@@ -68,23 +101,23 @@
v.handle,
beta_1_power,
beta_2_power,
- math_ops.cast(self._get_hyper('learning_rate'), grad_dtype),
+ lr_t,
beta_1_t,
beta_2_t,
- math_ops.cast(self._get_hyper('epsilon'), grad_dtype),
+ self._get_hyper('epsilon', var_dtype),
grad,
use_locking=self._use_locking,
use_nesterov=True)
def _resource_apply_sparse(self, grad, var, indices):
var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
+ beta_1_t = self._get_hyper('beta_1', var_dtype)
+ beta_2_t = self._get_hyper('beta_2', var_dtype)
local_step = math_ops.cast(self.iterations + 1, var_dtype)
- beta_1_t = math_ops.cast(self._get_hyper('beta_1'), var_dtype)
- beta_2_t = math_ops.cast(self._get_hyper('beta_2'), var_dtype)
beta_1_power = math_ops.pow(beta_1_t, local_step)
beta_2_power = math_ops.pow(beta_2_t, local_step)
- lr_t = math_ops.cast(self._get_hyper('learning_rate'), var_dtype)
- epsilon_t = math_ops.cast(self._get_hyper('epsilon'), var_dtype)
+ epsilon_t = self._get_hyper('epsilon', var_dtype)
lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))
# m_t = beta1 * m + (1 - beta1) * g_t
diff --git a/tensorflow/python/keras/optimizer_v2/nadam_test.py b/tensorflow/python/keras/optimizer_v2/nadam_test.py
index 9cc81b1..b7132bb 100644
--- a/tensorflow/python/keras/optimizer_v2/nadam_test.py
+++ b/tensorflow/python/keras/optimizer_v2/nadam_test.py
@@ -48,7 +48,7 @@
beta1=0.9,
beta2=0.999,
epsilon=1e-8):
- alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)
+ alpha_t = alpha * np.sqrt(1 - beta2**(t + 1)) / (1 - beta1**(t + 1))
m_t = beta1 * m + (1 - beta1) * g_t
v_t = beta2 * v + (1 - beta2) * g_t * g_t
@@ -97,9 +97,9 @@
beta1_power, beta2_power = get_beta_accumulators(opt, dtype)
# Run 3 steps of Nadam
- for t in range(1, 4):
- self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
- self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ for t in range(3):
+ self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**(t + 1), beta2_power.eval())
update.run()
var0_np, m0, v0 = nadam_update_numpy(
@@ -146,9 +146,9 @@
beta1_power, beta2_power = get_beta_accumulators(opt, dtype)
# Run 3 steps of Nadam
- for t in range(1, 4):
- self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
- self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ for t in range(3):
+ self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**(t + 1), beta2_power.eval())
update.run()
var0_np, m0, v0 = nadam_update_numpy(var0_np, grads0_np, t, m0, v0)
@@ -158,12 +158,51 @@
self.assertAllCloseAccordingToType(var0_np, var0.eval())
self.assertAllCloseAccordingToType(var1_np, var1.eval())
- def testBasic(self):
- self.doTestBasic(use_resource=False)
-
def testResourceBasic(self):
self.doTestBasic(use_resource=True)
+ def testBasicWithLearningRateDecay(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ learning_rate = 0.001
+ decay = 0.5
+ opt = nadam.Nadam(learning_rate=learning_rate, decay=decay)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = get_beta_accumulators(opt, dtype)
+
+ # Run 3 steps of Nadam
+ for t in range(3):
+ self.assertAllCloseAccordingToType(0.9**(t + 1), beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**(t + 1), beta2_power.eval())
+ update.run()
+
+ lr = learning_rate / (1 + decay * t)
+ var0_np, m0, v0 = nadam_update_numpy(
+ var0_np, grads0_np, t, m0, v0, alpha=lr)
+ var1_np, m1, v1 = nadam_update_numpy(
+ var1_np, grads1_np, t, m1, v1, alpha=lr)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index fa7cfa5..0101ea8 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -31,8 +31,9 @@
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras import initializers
-from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribution_strategy_context
@@ -114,7 +115,7 @@
"""
- def __init__(self, name):
+ def __init__(self, name, **kwargs):
"""Create a new Optimizer.
This must be called by the constructors of subclasses.
@@ -128,6 +129,7 @@
Args:
name: A non-empty string. The name to use for accumulators created
for the optimizer.
+ **kwargs: keyword arguments. Allowed to be {`decay`}
Raises:
ValueError: If name is malformed.
@@ -140,6 +142,12 @@
# dict: {variable name : {slot name : variable}}
self._slots = {}
self._weights = []
+
+ decay = kwargs.pop("decay", 0.0)
+ if decay < 0.:
+ raise ValueError("decay cannot be less than 0: {}".format(decay))
+ self._initial_decay = decay
+
self._prepared = False
def minimize(self,
@@ -345,9 +353,14 @@
else:
backend.set_value(self._hyper[name], value)
- def _get_hyper(self, name):
+ def _get_hyper(self, name, dtype=None):
value = self._hyper[name]
- return self._call_if_callable(value)
+ if callable(value):
+ value = value()
+ if dtype:
+ return math_ops.cast(value, dtype)
+ else:
+ return value
def __getattribute__(self, name):
"""Overridden to support hyperparameter access."""
@@ -422,6 +435,15 @@
self._prepare()
return self._iterations
+ def _decayed_lr(self, var_dtype):
+ """Get decayed learning rate as a Tensor with dtype=var_dtype."""
+ lr_t = self._get_hyper("learning_rate", var_dtype)
+ if self._initial_decay > 0.:
+ local_step = math_ops.cast(self.iterations, var_dtype)
+ decay_t = self._get_hyper("decay", var_dtype)
+ lr_t = lr_t / (1. + decay_t * local_step)
+ return lr_t
+
@abc.abstractmethod
def get_config(self):
"""Returns the config of the optimimizer.
@@ -528,7 +550,7 @@
variable = self._add_variable_with_custom_getter(
name=name,
shape=shape,
- getter=base_layer.make_variable,
+ getter=base_layer_utils.make_variable,
overwrite=True,
initializer=initializer,
dtype=dtype,
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop.py b/tensorflow/python/keras/optimizer_v2/rmsprop.py
index e34397c..6a5b334 100644
--- a/tensorflow/python/keras/optimizer_v2/rmsprop.py
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop.py
@@ -12,19 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""RMSProp for TensorFlow."""
+"""RMSprop for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
-from tensorflow.python.ops import math_ops
from tensorflow.python.training import training_ops
-class RMSProp(optimizer_v2.OptimizerV2):
- r"""Optimizer that implements the RMSProp algorithm.
+class RMSprop(optimizer_v2.OptimizerV2):
+ r"""Optimizer that implements the RMSprop algorithm.
A detailed description of rmsprop.
@@ -36,7 +35,7 @@
mean_square_t + \epsilon}$$
$$variable_t := variable_{t-1} - mom_t
- This implementation of RMSProp uses plain momentum, not Nesterov momentum.
+ This implementation of RMSprop uses plain momentum, not Nesterov momentum.
The centered version additionally maintains a moving average of the
gradients, and uses that average to estimate the variance:
@@ -58,8 +57,9 @@
momentum=0.0,
epsilon=1e-7,
centered=False,
- name="RMSProp"):
- """Construct a new RMSProp optimizer.
+ name="RMSprop",
+ **kwargs):
+ """Construct a new RMSprop optimizer.
Note that in the dense implementation of this algorithm, variables and their
corresponding accumulators (momentum, gradient moving average, square
@@ -83,17 +83,16 @@
True may help with training, but is slightly more expensive in terms of
computation and memory. Defaults to False.
name: Optional name prefix for the operations created when applying
- gradients. Defaults to "RMSProp".
-
- @compatibility(eager)
- When eager execution is enabled, `learning_rate`, `decay`, `momentum`, and
- `epsilon` can each be a callable that takes no arguments and returns the
- actual value to use. This can be useful for changing these values across
- different invocations of optimizer functions.
- @end_compatibility
+ gradients. Defaults to "RMSprop". @compatibility(eager) When eager
+ execution is enabled, `learning_rate`, `decay`, `momentum`, and
+ `epsilon` can each be a callable that takes no arguments and returns the
+ actual value to use. This can be useful for changing these values across
+ different invocations of optimizer functions. @end_compatibility
+ **kwargs: keyword arguments. Allowed to be {`decay`}
"""
- super(RMSProp, self).__init__(name)
+ super(RMSprop, self).__init__(name, **kwargs)
self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("decay", self._initial_decay)
self._set_hyper("rho", rho)
self._momentum = False
@@ -114,13 +113,13 @@
self.add_slot(var, "mg")
def _resource_apply_dense(self, grad, var):
+ var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
rms = self.get_slot(var, "rms")
mom = self.get_slot(var, "momentum")
- learning_rate = math_ops.cast(
- self._get_hyper("learning_rate"), grad.dtype.base_dtype)
- rho = math_ops.cast(self._get_hyper("rho"), grad.dtype.base_dtype)
- momentum = math_ops.cast(self._get_hyper("momentum"), grad.dtype.base_dtype)
- epsilon = math_ops.cast(self._get_hyper("epsilon"), grad.dtype.base_dtype)
+ rho = self._get_hyper("rho", var_dtype)
+ momentum = self._get_hyper("momentum", var_dtype)
+ epsilon = self._get_hyper("epsilon", var_dtype)
if self._centered:
mg = self.get_slot(var, "mg")
return training_ops.resource_apply_centered_rms_prop(
@@ -128,7 +127,7 @@
mg.handle,
rms.handle,
mom.handle,
- learning_rate,
+ lr_t,
rho,
momentum,
epsilon,
@@ -139,7 +138,7 @@
var.handle,
rms.handle,
mom.handle,
- learning_rate,
+ lr_t,
rho,
momentum,
epsilon,
@@ -147,13 +146,13 @@
use_locking=self._use_locking)
def _resource_apply_sparse(self, grad, var, indices):
+ var_dtype = var.dtype.base_dtype
+ lr_t = self._decayed_lr(var_dtype)
rms = self.get_slot(var, "rms")
mom = self.get_slot(var, "momentum")
- learning_rate = math_ops.cast(
- self._get_hyper("learning_rate"), grad.dtype.base_dtype)
- rho = math_ops.cast(self._get_hyper("rho"), grad.dtype.base_dtype)
- momentum = math_ops.cast(self._get_hyper("momentum"), grad.dtype.base_dtype)
- epsilon = math_ops.cast(self._get_hyper("epsilon"), grad.dtype.base_dtype)
+ rho = self._get_hyper("rho", var_dtype)
+ momentum = self._get_hyper("momentum", var_dtype)
+ epsilon = self._get_hyper("epsilon", var_dtype)
if self._centered:
mg = self.get_slot(var, "mg")
return training_ops.resource_sparse_apply_centered_rms_prop(
@@ -161,7 +160,7 @@
mg.handle,
rms.handle,
mom.handle,
- learning_rate,
+ lr_t,
rho,
momentum,
epsilon,
@@ -173,7 +172,7 @@
var.handle,
rms.handle,
mom.handle,
- learning_rate,
+ lr_t,
rho,
momentum,
epsilon,
@@ -182,12 +181,16 @@
use_locking=self._use_locking)
def get_config(self):
- config = super(RMSProp, self).get_config()
+ config = super(RMSprop, self).get_config()
config.update({
"learning_rate": self._serialize_hyperparameter("learning_rate"),
+ "decay": self._serialize_hyperparameter("decay"),
"rho": self._serialize_hyperparameter("rho"),
"momentum": self._serialize_hyperparameter("momentum"),
"epsilon": self._serialize_hyperparameter("epsilon"),
"centered": self._centered,
})
return config
+
+
+RMSProp = RMSprop
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
index 8d7afa5..a320cc0 100644
--- a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
@@ -28,6 +28,7 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.keras.optimizer_v2 import rmsprop
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
@@ -52,7 +53,7 @@
]
-class RMSPropOptimizerTest(test.TestCase):
+class RMSpropOptimizerTest(test.TestCase):
def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, rho, momentum,
epsilon, centered):
@@ -87,7 +88,7 @@
def testDense(self):
for (dtype, learning_rate, rho, momentum, epsilon, centered) in _TESTPARAMS:
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
# Initialize variables for numpy implementation.
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype)
@@ -98,7 +99,7 @@
var1 = resource_variable_ops.ResourceVariable(var1_np, dtype=dtype)
grads0 = constant_op.constant(grads0_np, dtype=dtype)
grads1 = constant_op.constant(grads1_np, dtype=dtype)
- opt = rmsprop.RMSProp(
+ opt = rmsprop.RMSprop(
learning_rate=learning_rate,
rho=rho,
momentum=momentum,
@@ -106,7 +107,7 @@
centered=centered)
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
if centered:
mg0 = opt.get_slot(var0, "mg")
@@ -135,9 +136,9 @@
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
- # Run 4 steps of RMSProp
+ # Run 4 steps of RMSprop
for _ in range(1, 5):
- update.run()
+ self.evaluate(update)
var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate, rho,
@@ -157,6 +158,73 @@
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+ def testDenseWithLearningRateDecay(self):
+ var0_np = np.array([1.0, 2.0])
+ grads0_np = np.array([0.1, 0.2])
+ var1_np = np.array([3.0, 4.0])
+ grads1_np = np.array([0.01, 0.2])
+
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ learning_rate = 0.01
+ rho = 0.9
+ momentum = 0.0
+ epsilon = 1e-7
+ centered = False
+ decay = 0.5
+ opt = rmsprop.RMSprop(
+ learning_rate=learning_rate,
+ rho=rho,
+ momentum=momentum,
+ epsilon=epsilon,
+ centered=centered,
+ decay=decay)
+
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ self.evaluate(variables.global_variables_initializer())
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertTrue(rms0 is not None)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertTrue(rms1 is not None)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertTrue(mom0 is not None)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertTrue(mom1 is not None)
+
+ mg0_np = np.array([0.0, 0.0])
+ mg1_np = np.array([0.0, 0.0])
+ rms0_np = np.array([0.0, 0.0])
+ rms1_np = np.array([0.0, 0.0])
+ mom0_np = np.array([0.0, 0.0])
+ mom1_np = np.array([0.0, 0.0])
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ # Run 4 steps of RMSprop
+ for t in range(2):
+ self.evaluate(update)
+
+ lr = learning_rate / (1 + decay * t)
+ var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
+ var0_np, grads0_np, mg0_np, rms0_np, mom0_np, lr, rho, momentum,
+ epsilon, centered)
+ var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
+ var1_np, grads1_np, mg1_np, rms1_np, mom1_np, lr, rho, momentum,
+ epsilon, centered)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(rms0_np, self.evaluate(rms0))
+ self.assertAllCloseAccordingToType(rms1_np, self.evaluate(rms1))
+ self.assertAllCloseAccordingToType(mom0_np, self.evaluate(mom0))
+ self.assertAllCloseAccordingToType(mom1_np, self.evaluate(mom1))
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.float32, dtypes.float64]:
with self.cached_session():
@@ -164,18 +232,18 @@
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
loss = pred * pred
- sgd_op = rmsprop.RMSProp(
+ sgd_op = rmsprop.RMSprop(
learning_rate=1.0,
rho=0.0,
momentum=0.0,
epsilon=0.0,
centered=False).minimize(
loss, var_list=[var0])
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
# Run 1 step of sgd
- sgd_op.run()
+ self.evaluate(sgd_op)
# Validate updated params
self.assertAllCloseAccordingToType([[0., 1.]],
self.evaluate(var0),
@@ -188,18 +256,18 @@
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
loss = pred * pred
- sgd_op = rmsprop.RMSProp(
+ sgd_op = rmsprop.RMSprop(
learning_rate=1.0,
rho=0.0,
momentum=0.0,
epsilon=1.0,
centered=True).minimize(
loss, var_list=[var0])
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
# Run 1 step of sgd
- sgd_op.run()
+ self.evaluate(sgd_op)
# Validate updated params
self.assertAllCloseAccordingToType([[-111, -138]],
self.evaluate(var0),
@@ -207,7 +275,7 @@
def testSparse(self):
for (dtype, learning_rate, rho, momentum, epsilon, centered) in _TESTPARAMS:
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
# Initialize variables for numpy implementation.
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)
@@ -224,14 +292,14 @@
grads1 = ops.IndexedSlices(
constant_op.constant(grads1_np),
constant_op.constant(grads1_np_indices), constant_op.constant([1]))
- opt = rmsprop.RMSProp(
+ opt = rmsprop.RMSprop(
learning_rate=learning_rate,
rho=rho,
momentum=momentum,
epsilon=epsilon,
centered=centered)
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
if centered:
mg0 = opt.get_slot(var0, "mg")
@@ -261,9 +329,9 @@
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
- # Run 4 steps of RMSProp
+ # Run 4 steps of RMSprop
for _ in range(1, 5):
- update.run()
+ self.evaluate(update)
var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np,
@@ -295,7 +363,7 @@
rho = lambda: 0.9
momentum = lambda: 0.0
epsilon = lambda: 1.0
- opt = rmsprop.RMSProp(learning_rate, rho, momentum, epsilon)
+ opt = rmsprop.RMSprop(learning_rate, rho, momentum, epsilon)
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], self.evaluate(var0))
diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py
index 09dd708..9c8020d 100644
--- a/tensorflow/python/keras/optimizers.py
+++ b/tensorflow/python/keras/optimizers.py
@@ -22,8 +22,16 @@
import six
from six.moves import zip # pylint: disable=redefined-builtin
+from tensorflow.python import tf2
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2
+from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2
+from tensorflow.python.keras.optimizer_v2 import adam as adam_v2
+from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2
+from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
+from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import clip_ops
@@ -796,16 +804,27 @@
Returns:
A Keras Optimizer instance.
"""
- all_classes = {
- 'sgd': SGD,
- 'rmsprop': RMSprop,
- 'adagrad': Adagrad,
- 'adadelta': Adadelta,
- 'adam': Adam,
- 'adamax': Adamax,
- 'nadam': Nadam,
- 'tfoptimizer': TFOptimizer,
- }
+ if tf2.enabled():
+ all_classes = {
+ 'adadelta': adadelta_v2.Adadelta,
+ 'adagrad': adagrad_v2.Adagrad,
+ 'adam': adam_v2.Adam,
+ 'adamax': adamax_v2.Adamax,
+ 'nadam': nadam_v2.Nadam,
+ 'rmsprop': rmsprop_v2.RMSprop,
+ 'sgd': gradient_descent_v2.SGD
+ }
+ else:
+ all_classes = {
+ 'adadelta': Adadelta,
+ 'adagrad': Adagrad,
+ 'adam': Adam,
+ 'adamax': Adamax,
+ 'nadam': Nadam,
+ 'rmsprop': RMSprop,
+ 'sgd': SGD,
+ 'tfoptimizer': TFOptimizer
+ }
# Make deserialization case-insensitive for built-in optimizers.
if config['class_name'].lower() in all_classes:
config['class_name'] = config['class_name'].lower()
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py
index 9664f09..46bb027 100644
--- a/tensorflow/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/optimizers_test.py
@@ -19,11 +19,14 @@
from __future__ import print_function
import gc
+import os
import weakref
+from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
+from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@@ -208,5 +211,40 @@
_ = keras.optimizers.Adam(clipnorm=-2.0)
+@test_util.run_all_in_graph_and_eager_modes
+class KerasV2OptimizersTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ('adadelta_tf2', 'adadelta', True), ('adadelta_tf1', 'adadelta', False),
+ ('adagrad_tf2', 'adagrad', True), ('adagrad_tf1', 'adagrad', False),
+ ('adam_tf2', 'adam', True), ('adam_tf1', 'adam', False),
+ ('adamax_tf2', 'adamax', True), ('adamax_tf1', 'adamax', False),
+ ('sgd_tf2', 'sgd', True), ('sgd_tf1', 'sgd', False),
+ ('nadam_tf2', 'nadam', True), ('nadam_tf1', 'nadam', False),
+ ('rmsprop_tf2', 'rmsprop', True), ('rmsprop_tf1', 'rmsprop', False))
+ def test_load_from_string(self, optimizer_string, tf2mode):
+ old_mode = os.environ.get('TF2_BEHAVIOR', None)
+ if tf2mode:
+ os.environ['TF2_BEHAVIOR'] = 'enabled'
+ else:
+ if 'TF2_BEHAVIOR' in os.environ:
+ del os.environ['TF2_BEHAVIOR']
+
+ # Sanity check.
+ self.assertEqual(tf2.enabled(), tf2mode)
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(1, input_shape=(10,)))
+ model.compile(optimizer_string, 'binary_crossentropy')
+
+ self.assertEqual(optimizer_string,
+ model.optimizer.__class__.__name__.lower())
+
+ model.fit(np.ones((10, 10), 'float32'), np.ones((10, 1), 'float32'))
+
+ if old_mode is not None:
+ os.environ['TF2_BEHAVIOR'] = old_mode
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/regularizers.py b/tensorflow/python/keras/regularizers.py
index cbcdae2..28b6ad4 100644
--- a/tensorflow/python/keras/regularizers.py
+++ b/tensorflow/python/keras/regularizers.py
@@ -20,7 +20,6 @@
import six
-from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
@@ -55,14 +54,12 @@
self.l2 = K.cast_to_floatx(l2)
def __call__(self, x):
- if self.l1 or self.l2:
- regularization = ops.convert_to_tensor(0., dtype=K.floatx())
- if self.l1:
- regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x))
- if self.l2:
- regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x))
- return regularization
- return None
+ regularization = 0.
+ if self.l1:
+ regularization += math_ops.reduce_sum(self.l1 * math_ops.abs(x))
+ if self.l2:
+ regularization += math_ops.reduce_sum(self.l2 * math_ops.square(x))
+ return regularization
def get_config(self):
return {'l1': float(self.l1), 'l2': float(self.l2)}
diff --git a/tensorflow/python/keras/utils/__init__.py b/tensorflow/python/keras/utils/__init__.py
index 8939044..61940ad 100644
--- a/tensorflow/python/keras/utils/__init__.py
+++ b/tensorflow/python/keras/utils/__init__.py
@@ -34,6 +34,7 @@
from tensorflow.python.keras.utils.io_utils import HDF5Matrix
from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model
from tensorflow.python.keras.utils.layer_utils import get_source_inputs
+from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.keras.utils.multi_gpu_utils import multi_gpu_model
from tensorflow.python.keras.utils.np_utils import normalize
from tensorflow.python.keras.utils.np_utils import to_categorical
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index 158a9a5..60677be 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -77,7 +77,7 @@
Returns:
The total number of scalars composing the weights
"""
- return int(np.sum([np.prod(p.get_shape().as_list()) for p in set(weights)]))
+ return int(sum(np.prod(p.get_shape().as_list()) for p in set(weights)))
def print_summary(model, line_length=None, positions=None, print_fn=None):
diff --git a/tensorflow/python/keras/utils/losses_utils.py b/tensorflow/python/keras/utils/losses_utils.py
new file mode 100644
index 0000000..d11d785
--- /dev/null
+++ b/tensorflow/python/keras/utils/losses_utils.py
@@ -0,0 +1,213 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# pylint: disable=protected-access
+"""Utilities related to loss functions."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend as K
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import confusion_matrix
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export('losses.Reduction', 'keras.losses.Reduction', v1=[])
+class ReductionV2(object):
+ """Types of loss reduction.
+
+ Contains the following values:
+ `NONE`: Un-reduced weighted losses with the same shape as input.
+ `SUM`: Scalar sum of weighted losses.
+ `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
+ """
+
+ NONE = None
+ SUM = 'sum'
+ SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
+
+ @classmethod
+ def all(cls):
+ return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)
+
+ @classmethod
+ def validate(cls, key):
+ if key not in cls.all():
+ raise ValueError('Invalid Reduction Key %s.' % key)
+
+
+def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
+ """Squeeze or expand last dimension if needed.
+
+ 1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
+ (using `confusion_matrix.remove_squeezable_dimensions`).
+ 2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
+ from the new rank of `y_pred`.
+ If `sample_weight` is scalar, it is kept scalar.
+
+ This will use static shape if available. Otherwise, it will add graph
+ operations, which could result in a performance hit.
+
+ Args:
+ y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
+ y_true: Optional label `Tensor` whose dimensions match `y_pred`.
+ sample_weight: Optional weight scalar or `Tensor` whose dimensions match
+ `y_pred`.
+
+ Returns:
+ Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
+ the last dimension squeezed,
+ `sample_weight` could be extended by one dimension.
+ """
+ if y_true is not None:
+ # squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
+ y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
+ y_true, y_pred)
+
+ if sample_weight is None:
+ return y_pred, y_true, None
+
+ sample_weight = ops.convert_to_tensor(sample_weight)
+ weights_shape = sample_weight.get_shape()
+ weights_rank = weights_shape.ndims
+ if weights_rank == 0: # If weights is scalar, do nothing.
+ return y_pred, y_true, sample_weight
+
+ y_pred_shape = y_pred.get_shape()
+ y_pred_rank = y_pred_shape.ndims
+ if (y_pred_rank is not None) and (weights_rank is not None):
+ # Use static rank.
+ if weights_rank - y_pred_rank == 1:
+ sample_weight = array_ops.squeeze(sample_weight, [-1])
+ elif y_pred_rank - weights_rank == 1:
+ sample_weight = array_ops.expand_dims(sample_weight, [-1])
+ return y_pred, y_true, sample_weight
+
+ # Use dynamic rank.
+ weights_rank_tensor = array_ops.rank(sample_weight)
+ rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
+ maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])
+
+ def _maybe_expand_weights():
+ return control_flow_ops.cond(
+ math_ops.equal(rank_diff,
+ -1), lambda: array_ops.expand_dims(sample_weight, [-1]),
+ lambda: sample_weight)
+
+ def _maybe_adjust_weights():
+ return control_flow_ops.cond(
+ math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
+ _maybe_expand_weights)
+
+ # squeeze or expand last dim of `sample_weight` if its rank differs by 1
+ # from the new rank of `y_pred`.
+ sample_weight = control_flow_ops.cond(
+ math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
+ _maybe_adjust_weights)
+ return y_pred, y_true, sample_weight
+
+
+def _safe_mean(losses, num_present):
+ """Computes a safe mean of the losses.
+
+ Args:
+ losses: `Tensor` whose elements contain individual loss measurements.
+ num_present: The number of measurable elements in `losses`.
+
+ Returns:
+ A scalar representing the mean of `losses`. If `num_present` is zero,
+ then zero is returned.
+ """
+ total_loss = math_ops.reduce_sum(losses)
+ return math_ops.div_no_nan(total_loss, num_present, name='value')
+
+
+def _num_elements(losses):
+ """Computes the number of elements in `losses` tensor."""
+ with ops.name_scope(None, 'num_elements', values=[losses]) as scope:
+ return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype)
+
+
+def _reduce_weighted_loss(weighted_losses,
+ reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
+ """Reduces the individual weighted loss measurements."""
+ if reduction == ReductionV2.NONE:
+ loss = weighted_losses
+ else:
+ loss = math_ops.reduce_sum(weighted_losses)
+ if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
+ loss = _safe_mean(loss, _num_elements(weighted_losses))
+ return loss
+
+
+def compute_weighted_loss(losses,
+ sample_weight=None,
+ reduction=ReductionV2.SUM_OVER_BATCH_SIZE,
+ name=None):
+ """Computes the weighted loss.
+
+ Args:
+ losses: `Tensor` of shape `[batch_size, d1, ... dN]`.
+ sample_weight: Optional `Tensor` whose rank is either 0, or the same rank as
+ `losses`, or be broadcastable to `losses`.
+ reduction: Type of `tf.losses.Reduction` to apply to loss. Default value is
+ `SUM_OVER_BATCH_SIZE`.
+ name: Optional name for the op.
+
+ Raises:
+ ValueError: If the shape of `sample_weight` is not compatible with `losses`.
+
+ Returns:
+ Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
+ `NONE`, this has the same shape as `losses`; otherwise, it is scalar.
+ """
+ ReductionV2.validate(reduction)
+ if sample_weight is None:
+ sample_weight = 1.0
+ with ops.name_scope(name, 'weighted_loss', (losses, sample_weight)):
+ # Save the `reduction` argument for loss normalization when distributing
+ # to multiple replicas.
+ # TODO(josh11b): Associate it with the returned op for more precision.
+ ops.get_default_graph()._last_loss_reduction = reduction # pylint: disable=protected-access
+
+ # Update dimensions of `sample_weight` to match with `losses` if possible.
+ losses, _, sample_weight = squeeze_or_expand_dimensions(
+ losses, None, sample_weight)
+ losses = ops.convert_to_tensor(losses)
+ input_dtype = losses.dtype
+ losses = math_ops.to_float(losses)
+ sample_weight = math_ops.to_float(sample_weight)
+
+ try:
+ # Broadcast weights if possible.
+ sample_weight = weights_broadcast_ops.broadcast_weights(
+ sample_weight, losses)
+ except ValueError:
+ # Reduce values to same ndim as weight array.
+ ndim = K.ndim(losses)
+ weight_ndim = K.ndim(sample_weight)
+ losses = K.mean(losses, axis=list(range(weight_ndim, ndim)))
+
+ sample_weight.get_shape().assert_is_compatible_with(losses.get_shape())
+ weighted_losses = math_ops.multiply(losses, sample_weight)
+ # Apply reduction function to the individual weighted losses.
+ loss = _reduce_weighted_loss(weighted_losses, reduction)
+ # Convert the result back to the input type.
+ loss = math_ops.cast(loss, input_dtype)
+ return loss
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 41099ba..de06ec6 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -270,7 +270,7 @@
],
)
-tf_py_test(
+cuda_py_test(
name = "ctc_loss_op_test",
size = "small",
srcs = ["ctc_loss_op_test.py"],
@@ -662,6 +662,18 @@
)
cuda_py_test(
+ name = "matrix_square_root_op_test",
+ size = "medium",
+ srcs = ["matrix_square_root_op_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:linalg_ops",
+ ],
+)
+
+cuda_py_test(
name = "matrix_solve_op_test",
size = "medium",
srcs = ["matrix_solve_op_test.py"],
@@ -819,6 +831,7 @@
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:io_ops",
"//tensorflow/python:io_ops_gen",
],
diff --git a/tensorflow/python/kernel_tests/aggregate_ops_test.py b/tensorflow/python/kernel_tests/aggregate_ops_test.py
index 0f15319..874d616 100644
--- a/tensorflow/python/kernel_tests/aggregate_ops_test.py
+++ b/tensorflow/python/kernel_tests/aggregate_ops_test.py
@@ -61,7 +61,7 @@
for dtype in self._supported_types():
for count in range(1, self._MAX_N + 1):
data = [self._buildData((2, 2), dtype) for _ in range(count)]
- actual = sess.run(math_ops.add_n(data))
+ actual = self.evaluate(math_ops.add_n(data))
expected = np.sum(np.vstack(
[np.expand_dims(d, 0) for d in data]), axis=0)
tol = 5e-3 if dtype == dtypes.float16 else 5e-7
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index b9d9d54..afc158f 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -556,7 +556,8 @@
def testInt64GPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
- with self.session(use_gpu=True, force_gpu=True):
+
+ with test_util.force_gpu():
x = constant_op.constant([1., 2., 3.])
begin = constant_op.constant([2], dtype=dtypes.int64)
end = constant_op.constant([3], dtype=dtypes.int64)
@@ -1187,18 +1188,18 @@
self.assertAllEqual(x.numpy(), y.numpy())
self.assertTrue(device in y.device.lower())
- with ops.device("gpu:0"):
+ with test_util.force_gpu():
a = constant_op.constant([[2], [3]], dtype=dtypes.float32)
- with ops.device("gpu:0"):
+ with test_util.force_gpu():
b = array_ops.identity(a)
_test(a, b, "gpu")
- with ops.device("cpu:0"):
+ with test_util.force_cpu():
c = array_ops.identity(b)
_test(b, c, "cpu")
- with ops.device("cpu:0"):
+ with test_util.force_cpu():
d = array_ops.identity(c)
_test(c, d, "cpu")
- with ops.device("gpu:0"):
+ with test_util.force_gpu():
e = array_ops.identity(d)
_test(d, e, "gpu")
diff --git a/tensorflow/python/kernel_tests/attention_ops_test.py b/tensorflow/python/kernel_tests/attention_ops_test.py
index 14db06b..00dba99 100644
--- a/tensorflow/python/kernel_tests/attention_ops_test.py
+++ b/tensorflow/python/kernel_tests/attention_ops_test.py
@@ -85,7 +85,7 @@
# Evaluate the TensorFlow Graph.
with self.cached_session() as sess:
- value_rows, value_cols = sess.run([glimpse_rows, glimpse_cols])
+ value_rows, value_cols = self.evaluate([glimpse_rows, glimpse_cols])
# Check dimensions of returned glimpse.
self.assertEqual(value_rows.shape[1], glimpse_sizes[0])
diff --git a/tensorflow/python/kernel_tests/barrier_ops_test.py b/tensorflow/python/kernel_tests/barrier_ops_test.py
index 4d36b3a..495bbe7 100644
--- a/tensorflow/python/kernel_tests/barrier_ops_test.py
+++ b/tensorflow/python/kernel_tests/barrier_ops_test.py
@@ -229,7 +229,7 @@
insert_ops = [b.insert_many(0, [k], [v]) for k, v in zip(keys, values)]
take_t = b.take_many(10)
- sess.run(insert_ops)
+ self.evaluate(insert_ops)
self.assertEquals(size_t.eval(), [10])
indices_val, keys_val, values_val = sess.run(
@@ -491,9 +491,9 @@
b = data_flow_ops.Barrier(
(dtypes.float32, dtypes.float32), shapes=((), ()), name="B")
take_t = b.take_many(1, allow_small_batch=True)
- sess.run(b.close(cancel))
+ self.evaluate(b.close(cancel))
with self.assertRaisesOpError("is closed and has insufficient elements"):
- sess.run(take_t)
+ self.evaluate(take_t)
def testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(self):
self._testClosedEmptyBarrierTakeManyAllowSmallBatchRaises(cancel=False)
diff --git a/tensorflow/python/kernel_tests/base64_ops_test.py b/tensorflow/python/kernel_tests/base64_ops_test.py
index 1b39994..bb903d8 100644
--- a/tensorflow/python/kernel_tests/base64_ops_test.py
+++ b/tensorflow/python/kernel_tests/base64_ops_test.py
@@ -93,7 +93,7 @@
decoded = string_ops.decode_base64(encoded)
with self.cached_session() as sess:
- encoded_value, decoded_value = sess.run([encoded, decoded])
+ encoded_value, decoded_value = self.evaluate([encoded, decoded])
self.assertEqual(encoded_value.shape, msg.shape)
self.assertEqual(decoded_value.shape, msg.shape)
diff --git a/tensorflow/python/kernel_tests/benchmark_test.py b/tensorflow/python/kernel_tests/benchmark_test.py
index 5777a5d..bffa5e6 100644
--- a/tensorflow/python/kernel_tests/benchmark_test.py
+++ b/tensorflow/python/kernel_tests/benchmark_test.py
@@ -21,9 +21,12 @@
import os
import random
+import numpy as np
+
from tensorflow.core.util import test_log_pb2
from tensorflow.python.client import session
-from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@@ -64,11 +67,17 @@
"other_key": "string"})
def benchmark_times_an_op(self):
+ input_size = 5
with session.Session(config=benchmark.benchmark_config()) as sess:
- a = constant_op.constant(0.0)
+ a = array_ops.placeholder(dtype=dtypes.float32, shape=(input_size))
a_plus_a = a + a
return self.run_op_benchmark(
- sess, a_plus_a, min_iters=1000, store_trace=True, name="op_benchmark")
+ sess,
+ a_plus_a,
+ feed_dict={a: np.arange(input_size)},
+ min_iters=1000,
+ store_trace=True,
+ name="op_benchmark")
class BenchmarkTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/bitcast_op_test.py b/tensorflow/python/kernel_tests/bitcast_op_test.py
index 4dcf218..5ceffcf 100644
--- a/tensorflow/python/kernel_tests/bitcast_op_test.py
+++ b/tensorflow/python/kernel_tests/bitcast_op_test.py
@@ -21,6 +21,7 @@
import numpy as np
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -28,7 +29,7 @@
class BitcastTest(test.TestCase):
def _testBitcast(self, x, datatype, shape):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
tf_ans = array_ops.bitcast(x, datatype)
out = self.evaluate(tf_ans)
buff_after = memoryview(out).tobytes()
diff --git a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
index adfb094..1a7b1a7 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/quantile_ops_test.py
@@ -132,8 +132,8 @@
quantile_accumulator_handle_1, num_features=1)
quantiles = boosted_trees_ops.boosted_trees_bucketize(
[self._feature_0, self._feature_1], bucket_0 + bucket_1)
- sess.run([summary_op_0, summary_op_1])
- sess.run([flush_op_0, flush_op_1])
+ self.evaluate([summary_op_0, summary_op_1])
+ self.evaluate([flush_op_0, flush_op_1])
self.assertAllClose(self._feature_0_boundaries, bucket_0[0].eval())
self.assertAllClose(self._feature_1_boundaries, bucket_1[0].eval())
@@ -158,7 +158,7 @@
self._example_weights)
with ops.control_dependencies([summaries]):
flush = accumulator.flush()
- sess.run(flush)
+ self.evaluate(flush)
self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
save.save(sess, save_path)
@@ -185,12 +185,12 @@
summaries = accumulator.add_summaries([self._feature_0, self._feature_1],
self._example_weights)
- sess.run(summaries)
+ self.evaluate(summaries)
buckets = accumulator.get_bucket_boundaries()
self.assertAllClose([], buckets[0].eval())
self.assertAllClose([], buckets[1].eval())
save.save(sess, save_path)
- sess.run(accumulator.flush())
+ self.evaluate(accumulator.flush())
self.assertAllClose(self._feature_0_boundaries, buckets[0].eval())
self.assertAllClose(self._feature_1_boundaries, buckets[1].eval())
diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
index e4c5431..e1036b0 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
@@ -67,14 +67,14 @@
self.assertAllEqual([[1, 2], [1, 2]], self.evaluate(node_ids_list))
self.assertAllClose([[0.004775, 0.41184], [0.02823, 0.41184]],
- sess.run(gains_list))
+ self.evaluate(gains_list))
self.assertAllEqual([[1, 1], [1, 1]], self.evaluate(thresholds_list))
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
self.assertAllClose([[[-.416667], [.568966]], [[-.6], [-.75]]],
- sess.run(left_node_contribs_list))
+ self.evaluate(left_node_contribs_list))
self.assertAllClose([[[-.592593], [-.75]], [[-.076923], [.568966]]],
- sess.run(right_node_contribs_list))
+ self.evaluate(right_node_contribs_list))
def testCalculateBestGainsWithL2(self):
"""Testing Gain calculation with L2."""
@@ -115,14 +115,14 @@
self.assertAllEqual([[1, 2], [1, 2]], self.evaluate(node_ids_list))
self.assertAllClose([[0., 0.33931375], [0.01879096, 0.33931375]],
- sess.run(gains_list))
+ self.evaluate(gains_list))
self.assertAllEqual([[0, 1], [1, 1]], self.evaluate(thresholds_list))
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
self.assertAllClose([[[0.], [.485294]], [[-.5], [-.6]]],
- sess.run(left_node_contribs_list))
+ self.evaluate(left_node_contribs_list))
self.assertAllClose([[[-.424658], [-.6]], [[-.043478], [.485294]]],
- sess.run(right_node_contribs_list))
+ self.evaluate(right_node_contribs_list))
def testCalculateBestGainsWithL1(self):
"""Testing Gain calculation with L1."""
@@ -166,14 +166,14 @@
self.assertAllEqual([[1, 2], [1, 2]], self.evaluate(node_ids_list))
self.assertAllClose([[[0.0], [0.3965517]], [[-0.4], [-0.5]]],
- sess.run(left_node_contribs_list))
+ self.evaluate(left_node_contribs_list))
self.assertAllClose([[[-0.3333333], [-0.5]], [[0.0], [0.396552]]],
- sess.run(right_node_contribs_list))
+ self.evaluate(right_node_contribs_list))
# Gain should also include an adjustment of the gradient by l1.
self.assertAllClose([[0.0, 0.191207], [0.01, 0.191207]],
- sess.run(gains_list))
+ self.evaluate(gains_list))
def testCalculateBestGainsWithTreeComplexity(self):
"""Testing Gain calculation with L2."""
@@ -217,15 +217,15 @@
self.assertAllEqual([[1, 2], [1, 2]], self.evaluate(node_ids_list))
self.assertAllClose([[-3., -2.66068625], [-2.98120904, -2.66068625]],
- sess.run(gains_list))
+ self.evaluate(gains_list))
self.assertAllEqual([[0, 1], [1, 1]], self.evaluate(thresholds_list))
# The left node contrib will be later added to the previous node value to
# make the left node value, and the same for right node contrib.
self.assertAllClose([[[0.], [.485294]], [[-.5], [-.6]]],
- sess.run(left_node_contribs_list))
+ self.evaluate(left_node_contribs_list))
self.assertAllClose([[[-.424658], [-.6]], [[-.043478], [.485294]]],
- sess.run(right_node_contribs_list))
+ self.evaluate(right_node_contribs_list))
def testCalculateBestGainsWithMinNodeWeight(self):
"""Testing Gain calculation without any regularization."""
@@ -270,9 +270,9 @@
self.assertAllClose([[0.384314], [0.098013]], self.evaluate(gains_list))
self.assertAllEqual([[1], [1]], self.evaluate(thresholds_list))
self.assertAllClose([[[0.4852941]], [[-.6]]],
- sess.run(left_node_contribs_list))
+ self.evaluate(left_node_contribs_list))
self.assertAllClose([[[-0.75]], [[-0.014925]]],
- sess.run(right_node_contribs_list))
+ self.evaluate(right_node_contribs_list))
def testCalculateBestGainsWithMinNodeWeightNoSplitOnFeturePossible(self):
"""Testing Gain calculation without any regularization."""
diff --git a/tensorflow/python/kernel_tests/bucketize_op_test.py b/tensorflow/python/kernel_tests/bucketize_op_test.py
index 9575b28..f40ca82 100644
--- a/tensorflow/python/kernel_tests/bucketize_op_test.py
+++ b/tensorflow/python/kernel_tests/bucketize_op_test.py
@@ -56,7 +56,7 @@
with self.session(use_gpu=True) as sess:
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError, "Expected sorted boundaries"):
- sess.run(op)
+ self.evaluate(op)
def testBoundariesNotList(self):
with self.assertRaisesRegexp(
diff --git a/tensorflow/python/kernel_tests/cast_op_test.py b/tensorflow/python/kernel_tests/cast_op_test.py
index cdeaf7b..2cfe084 100644
--- a/tensorflow/python/kernel_tests/cast_op_test.py
+++ b/tensorflow/python/kernel_tests/cast_op_test.py
@@ -229,7 +229,7 @@
[lo, lo + 1, lo // 2, hi // 2, hi - 1, hi], dtype=in_type)
y = math_ops.saturate_cast(x, dtype=out_type)
self.assertEqual(y.dtype, out_type)
- x, y = sess.run([x, y])
+ x, y = self.evaluate([x, y])
correct = np.maximum(out_type.min, np.minimum(out_type.max, x))
self.assertAllEqual(correct, y)
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index e96b277..1a509a4 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -97,7 +97,7 @@
class CholeskyOpTest(test.TestCase):
def _verifyCholeskyBase(self, sess, x, chol, verification):
- chol_np, verification_np = sess.run([chol, verification])
+ chol_np, verification_np = self.evaluate([chol, verification])
self.assertAllClose(x, verification_np)
self.assertShapeEqual(x, chol)
# Check that the cholesky is lower triangular, and has positive diagonal
@@ -183,8 +183,8 @@
matrix2 = math_ops.matmul(matrix2, matrix2, adjoint_a=True)
c1 = linalg_ops.cholesky(matrix1)
c2 = linalg_ops.cholesky(matrix2)
- c1_val, c2_val = sess.run([c1, c2])
- self.assertAllEqual(c1_val, c2_val)
+ c1_val, c2_val = self.evaluate([c1, c2])
+ self.assertAllClose(c1_val, c2_val)
class CholeskyGradTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/compare_and_bitpack_op_test.py b/tensorflow/python/kernel_tests/compare_and_bitpack_op_test.py
index e1928c5..215ea97 100644
--- a/tensorflow/python/kernel_tests/compare_and_bitpack_op_test.py
+++ b/tensorflow/python/kernel_tests/compare_and_bitpack_op_test.py
@@ -20,6 +20,7 @@
import numpy as np
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -30,7 +31,7 @@
x, threshold,
truth,
expected_err_re=None):
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
ans = math_ops.compare_and_bitpack(x, threshold)
if expected_err_re is None:
tf_ans = self.evaluate(ans)
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index 6944d73..27137f7 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -23,6 +23,7 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradient_checker
@@ -65,7 +66,7 @@
self.assertAllEqual(result[:, 4:], params[p2])
def testInt32GPU(self):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
p1 = np.random.rand(2, 3).astype("i")
p2 = np.random.rand(2, 3).astype("i")
x1 = constant_op.constant(p1)
@@ -76,13 +77,13 @@
self.assertAllEqual(result[2:, :], p2)
def testRefType(self):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
p1 = np.random.rand(4, 4).astype("f")
p2 = np.random.rand(4, 4).astype("f")
v1 = variables.Variable(p1)
v2 = variables.Variable(p2)
c = array_ops.concat([v1, v2], 0)
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
result = self.evaluate(c)
self.assertEqual(result.shape, c.get_shape())
@@ -172,7 +173,7 @@
# Test both positive and negative concat axis.
# -2 and 1 correspond to the same axis for 3-dimensional tensors.
for axis in [-2, 1]:
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
inp = []
inp_tensors = []
for x in [1, 2, 6]:
@@ -203,7 +204,7 @@
self._testGradientsSimple(dtypes.complex64)
def testGradientsFirstDim(self):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
inp = []
inp_tensors = []
for x in [1, 2, 6]:
@@ -230,7 +231,7 @@
# Test both positive and negative concat axis.
# -1 and 2 correspond to the same axis for 3-dimensional tensors.
for axis in [-1, 2]:
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
inp = []
inp_tensors = []
for x in [1, 2, 6]:
@@ -261,7 +262,7 @@
# Random dim to concat on
concat_dim = np.random.randint(5)
concat_dim_sizes = np.random.randint(1, 5, size=num_tensors)
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
inp = []
inp_tensors = []
for x in concat_dim_sizes:
@@ -358,7 +359,7 @@
def testZeroSize(self):
# Verify that concat doesn't crash and burn for zero size inputs
np.random.seed(7)
- with self.session(use_gpu=True) as sess:
+ with test_util.use_gpu():
for shape0 in (), (2,):
axis = len(shape0)
for shape1 in (), (3,):
@@ -370,10 +371,10 @@
# TODO(irving): Make tf.concat handle map, then drop list().
xs = list(map(constant_op.constant, [x0, x1]))
c = array_ops.concat(xs, axis)
- self.assertAllEqual(c.eval(), correct)
+ self.assertAllEqual(self.evaluate(c), correct)
# Check gradients
dc = np.random.randn(*c.get_shape().as_list())
- dxs = sess.run(gradients_impl.gradients(c, xs, dc))
+ dxs = self.evaluate(gradients_impl.gradients(c, xs, dc))
self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))
def testTensorConcatDim0Grad(self):
@@ -473,18 +474,17 @@
def testConcatTuple(self):
c1 = np.random.rand(4, 4)
c2 = np.random.rand(4, 4)
- with self.cached_session():
- concat_list_t = array_ops.concat([c1, c2], 0)
- concat_tuple_t = array_ops.concat((c1, c2), 0)
- self.assertAllEqual(concat_list_t.eval(), self.evaluate(concat_tuple_t))
+ concat_list_t = array_ops.concat([c1, c2], 0)
+ concat_tuple_t = array_ops.concat((c1, c2), 0)
+ self.assertAllEqual(
+ self.evaluate(concat_list_t), self.evaluate(concat_tuple_t))
def testConcatNoScalars(self):
- with self.cached_session():
- scalar = constant_op.constant(7)
- dim = array_ops.placeholder(dtypes.int32)
- with self.assertRaisesRegexp(
- ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"):
- array_ops.concat([scalar, scalar, scalar], dim)
+ scalar = constant_op.constant(7)
+ dim = array_ops.placeholder(dtypes.int32)
+ with self.assertRaisesRegexp(
+ ValueError, r"Can't concatenate scalars \(use tf\.stack instead\)"):
+ array_ops.concat([scalar, scalar, scalar], dim)
# important as gpu implementation could fail if
# shared memory is not large for all the inputs
@@ -523,21 +523,21 @@
self.assertAllEqual(result[index], params[p[i]])
def testConcatEmpty(self):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
t1 = []
t2 = []
- output = gen_array_ops.concat_v2([t1, t2], 0).eval()
- self.assertFalse(output) # Checks that output is empty
+ output = gen_array_ops.concat_v2([t1, t2], 0)
+ self.assertFalse(self.evaluate(output)) # Checks that output is empty
def testConcatInvalidAxis(self):
with self.assertRaises(ValueError):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
t1 = [1]
t2 = [2]
gen_array_ops.concat_v2([t1, t2], 1).eval()
def testConcatNegativeAxis(self):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
@@ -608,7 +608,7 @@
def testConcatAxisType(self):
for dtype in [dtypes.int32, dtypes.int64]:
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
@@ -621,7 +621,7 @@
class ConcatOffsetTest(test.TestCase):
def testBasic(self):
- with self.session(use_gpu=True) as sess:
+ with test_util.use_gpu():
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
@@ -631,49 +631,45 @@
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
def testNotVector(self):
- with self.cached_session() as sess:
- cdim = constant_op.constant(1, dtypes.int32)
- s0 = constant_op.constant([[2, 3, 5]], dtypes.int32)
- s1 = constant_op.constant([[2, 7, 5]], dtypes.int32)
- off = gen_array_ops.concat_offset(cdim, [s0, s1])
- with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- r"should be a vector"):
- sess.run(off)
+ cdim = constant_op.constant(1, dtypes.int32)
+ s0 = constant_op.constant([[2, 3, 5]], dtypes.int32)
+ s1 = constant_op.constant([[2, 7, 5]], dtypes.int32)
+ off = gen_array_ops.concat_offset(cdim, [s0, s1])
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ r"should be a vector"):
+ self.evaluate(off)
def testConcatDimOutOfRange(self):
- with self.cached_session() as sess:
- cdim = constant_op.constant(4, dtypes.int32)
- s0 = constant_op.constant([2, 3, 5], dtypes.int32)
- s1 = constant_op.constant([2, 7, 5], dtypes.int32)
- off = gen_array_ops.concat_offset(cdim, [s0, s1])
- with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- r"Concat dim is out of range: 4 vs. 3"):
- sess.run(off)
+ cdim = constant_op.constant(4, dtypes.int32)
+ s0 = constant_op.constant([2, 3, 5], dtypes.int32)
+ s1 = constant_op.constant([2, 7, 5], dtypes.int32)
+ off = gen_array_ops.concat_offset(cdim, [s0, s1])
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ r"Concat dim is out of range: 4 vs. 3"):
+ self.evaluate(off)
def testDimMismatch(self):
- with self.cached_session() as sess:
- cdim = constant_op.constant(1, dtypes.int32)
- s0 = constant_op.constant([2, 3, 5], dtypes.int32)
- s1 = constant_op.constant([2, 7, 5, 10], dtypes.int32)
- off = gen_array_ops.concat_offset(cdim, [s0, s1])
- with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- r"should contain 3 elem"):
- sess.run(off)
+ cdim = constant_op.constant(1, dtypes.int32)
+ s0 = constant_op.constant([2, 3, 5], dtypes.int32)
+ s1 = constant_op.constant([2, 7, 5, 10], dtypes.int32)
+ off = gen_array_ops.concat_offset(cdim, [s0, s1])
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ r"should contain 3 elem"):
+ self.evaluate(off)
def testSizeMismatch(self):
- with self.cached_session() as sess:
- cdim = constant_op.constant(1, dtypes.int32)
- s0 = constant_op.constant([2, 3, 5], dtypes.int32)
- s1 = constant_op.constant([2, 7, 10], dtypes.int32)
- off = gen_array_ops.concat_offset(cdim, [s0, s1])
- with self.assertRaisesRegexp(
- errors_impl.InvalidArgumentError,
- r"All dimensions except 1 must match. Input 1 has shape \[2 7 10\] "
- r"and doesn't match input 0 with shape \[2 3 5\]."):
- sess.run(off)
+ cdim = constant_op.constant(1, dtypes.int32)
+ s0 = constant_op.constant([2, 3, 5], dtypes.int32)
+ s1 = constant_op.constant([2, 7, 10], dtypes.int32)
+ off = gen_array_ops.concat_offset(cdim, [s0, s1])
+ with self.assertRaisesRegexp(
+ errors_impl.InvalidArgumentError,
+ r"All dimensions except 1 must match. Input 1 has shape \[2 7 10\] "
+ r"and doesn't match input 0 with shape \[2 3 5\]."):
+ self.evaluate(off)
def testNegativeDim(self):
- with self.session(use_gpu=True) as sess:
+ with test_util.use_gpu():
cdim = constant_op.constant(-2, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
diff --git a/tensorflow/python/kernel_tests/conditional_accumulator_test.py b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
index 8388070..7ee1a4b 100644
--- a/tensorflow/python/kernel_tests/conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/conditional_accumulator_test.py
@@ -111,7 +111,7 @@
for e in elems:
q.apply_grad((e,)).run()
- result = sess.run(q.take_grad(1))
+ result = self.evaluate(q.take_grad(1))
self.assertEqual(sum(elems) / len(elems), result)
@@ -458,7 +458,7 @@
results = []
def take_grad():
- results.append(sess.run(takeg_t))
+ results.append(self.evaluate(takeg_t))
threads = [self.checkedThread(target=take_grad) for _ in range(10)]
@@ -490,7 +490,7 @@
return_array = []
def take_grad():
- return_array.append(sess.run(takeg_t))
+ return_array.append(self.evaluate(takeg_t))
accum_thread = self.checkedThread(target=apply_grad)
takeg_thread = self.checkedThread(target=take_grad)
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 112e201..9c3c96b 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -219,16 +219,28 @@
def testShapeInconsistent(self):
with ops.Graph().as_default():
- c = constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[10])
+ c = constant_op.constant_v1([1, 2, 3, 4, 5, 6, 7], shape=[10])
+ self.assertEqual(c.get_shape(), [10])
+
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ TypeError, "Expected Tensor's shape"):
+ c = constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[10])
+
+ def testPromotionShapes(self):
+ with ops.Graph().as_default():
+ c = constant_op.constant([7], shape=[10])
+ self.assertEqual(c.get_shape(), [10])
+ with ops.Graph().as_default():
+ c = constant_op.constant(3, shape=[10])
self.assertEqual(c.get_shape(), [10])
# pylint: disable=g-long-lambda
def testShapeWrong(self):
with ops.Graph().as_default():
- with self.assertRaisesWithPredicateMatch(
- ValueError,
- lambda e: ("Too many elements provided. Needed at most 5, "
- "but received 7" == str(e))):
+ with self.assertRaisesRegexp(ValueError, "Too many elements provided."):
+ constant_op.constant_v1([1, 2, 3, 4, 5, 6, 7], shape=[5])
+ with self.assertRaisesRegexp(TypeError, "Expected Tensor's shape"):
constant_op.constant([1, 2, 3, 4, 5, 6, 7], shape=[5])
# pylint: enable=g-long-lambda
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 0d6d2cc..37654ab 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -435,6 +435,20 @@
self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
+ def testCondWithTensorArrayGrad(self):
+ with self.cached_session() as sess:
+ with ops.device(test.gpu_device_name()):
+ pred = array_ops.placeholder(dtypes.bool, [])
+ x = constant_op.constant([1.0, 2.0, 3.0])
+ y = control_flow_ops.cond(
+ pred, lambda: functional_ops.map_fn(lambda z: z * 2.0, x),
+ lambda: constant_op.constant([1.0, 1.0, 1.0]))
+ g = gradients_impl.gradients(y, x)[0]
+
+ self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0])
+ # TODO(b/119791601): Enable this.
+ # self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])
+
@test_util.disable_control_flow_v2("b/113293074")
def testCondIndexedSlicesDifferentTypes(self):
with self.cached_session():
@@ -717,7 +731,7 @@
with ops.control_dependencies([v_t_op]):
orig_v = array_ops.identity(v)
merged_op = control_flow_ops.merge([assign_v, orig_v])
- self.assertAllEqual([1.0], sess.run(merged_op.output))
+ self.assertAllEqual([1.0], self.evaluate(merged_op.output))
def testCondSwitchIdentity(self):
# Make sure the recv identity is not removed by optimization.
@@ -731,7 +745,7 @@
return control_flow_ops.Assert(False, ["Wrong branch!!!"])
r = control_flow_ops.cond(pred, fn1, fn2)
- sess.run(r)
+ self.evaluate(r)
def testCondRecvIdentity(self):
# Make sure the switch identity is not removed by optimization.
@@ -747,7 +761,7 @@
return control_flow_ops.Assert(False, ["Wrong branch!!!"])
r = control_flow_ops.cond(pred, fn1, fn2)
- sess.run(r)
+ self.evaluate(r)
def testCondGrad_1(self):
with self.cached_session():
@@ -849,13 +863,13 @@
# Should just be [1, 1], but possibly a sparse representation
gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 1})
dense_gv = [
- sum([y for (x, y) in zip(gi, gv) if x == i]) for i in range(2)
+ sum(y for (x, y) in zip(gi, gv) if x == i) for i in range(2)
]
self.assertAllEqual(dense_gv, [1.0, 1.0])
# Should be [0, 2], as the else forwards v1[1] twice
gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 3})
dense_gv = [
- sum([y for (x, y) in zip(gi, gv) if x == i]) for i in range(2)
+ sum(y for (x, y) in zip(gi, gv) if x == i) for i in range(2)
]
self.assertAllEqual(dense_gv, [0.0, 2.0])
@@ -2119,7 +2133,7 @@
r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=p_iters)
grad_a, grad_v = gradients_impl.gradients(r, [a, v])
- grad_a_val, grad_v_val = sess.run([grad_a, grad_v])
+ grad_a_val, grad_v_val = self.evaluate([grad_a, grad_v])
self.assertAllClose(216.0, grad_a_val)
self.assertAllClose(81.0, grad_v_val)
@@ -2250,7 +2264,7 @@
i, x = control_flow_ops.while_loop(lambda i, x: i < 3, outer_body, [0, 0.0])
with self.cached_session() as sess:
- i_val, x_val = sess.run([i, x])
+ i_val, x_val = self.evaluate([i, x])
self.assertEqual(i_val, 3)
self.assertAllClose(x_val, 1.0)
@@ -2279,7 +2293,7 @@
r_flattened = nest.flatten(r)
self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0],
- sess.run(r_flattened))
+ self.evaluate(r_flattened))
def testWhile_NestedBadArityFails(self):
with self.cached_session():
@@ -2596,7 +2610,7 @@
self.evaluate(variables.global_variables_initializer())
optimizer = gradient_descent.GradientDescentOptimizer(0.01)
op = optimizer.minimize(s)
- sess.run(op)
+ self.evaluate(op)
self.assertAllClose([[0.98000002, 1.98000002]], self.evaluate(x))
@test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
@@ -2795,7 +2809,7 @@
self.assertAllClose([156.0, 400.0], sess.run(r, feed_dict=feed_dict))
name = "gradients/while/stopped_grad"
all_ops = x.graph.get_operations()
- self.assertFalse(any([name in op.name for op in all_ops]))
+ self.assertFalse(any(name in op.name for op in all_ops))
@test_util.disable_control_flow_v2("b/117954949")
def testWhileGradGradFail(self):
@@ -3019,19 +3033,19 @@
((x > y, a), (x > y, b)), default=c, exclusive=True)
variables.global_variables_initializer().run()
- self.assertAllEqual(sess.run([v0, v1, v2]), [-1] * 3)
+ self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
self.assertEqual(2, self.evaluate(r2))
- self.assertAllEqual(sess.run([v0, v1, v2]), [-1, -1, 2])
+ self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, -1, 2])
variables.global_variables_initializer().run()
- self.assertAllEqual(sess.run([v0, v1, v2]), [-1] * 3)
+ self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
self.assertEqual(1, self.evaluate(r1))
- self.assertAllEqual(sess.run([v0, v1, v2]), [-1, 1, -1])
+ self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, 1, -1])
variables.global_variables_initializer().run()
- self.assertAllEqual(sess.run([v0, v1, v2]), [-1] * 3)
+ self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
self.assertEqual(0, self.evaluate(r0))
- self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1])
+ self.assertAllEqual(self.evaluate([v0, v1, v2]), [0, -1, -1])
@test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testOneOpCond(self):
@@ -3069,7 +3083,7 @@
# Fetching v directly will result in an uninitialized error
with self.assertRaisesOpError("Attempting to use uninitialized value"):
- sess.run([c, v])
+ self.evaluate([c, v])
# Use a control dependency to ensure init_variable is run
# while asking for c
@@ -3077,7 +3091,7 @@
name="real_tensor",
output_tensor=v._ref(), # pylint: disable=protected-access
dependencies=[v.initializer])
- c_val, real_v_val = sess.run([c, real_v])
+ c_val, real_v_val = self.evaluate([c, real_v])
# Ensure the result of 'real_c' is the same as 'c'
self.assertAllEqual(10, c_val)
@@ -3170,7 +3184,7 @@
# Runs "init" before fetching v1 and v2.
init.run()
- v1_val, v2_val = sess.run([v1, v2])
+ v1_val, v2_val = self.evaluate([v1, v2])
# Ensure that v1 and v2 are initialized
self.assertAllClose([0.0], v1_val)
@@ -3331,7 +3345,7 @@
cond = constant_op.constant(True, dtypes.bool)
v_f, v_t = control_flow_ops.switch(constant_qint, cond)
result = control_flow_ops.merge([v_f, v_t])
- sess.run(result)
+ self.evaluate(result)
def testQIntRefSwitchMerge(self):
with self.cached_session(use_gpu=test.is_gpu_available()) as sess:
@@ -3344,7 +3358,15 @@
cond = constant_op.constant(True, dtypes.bool)
v_f, v_t = control_flow_ops.ref_switch(var_qint, cond)
result = control_flow_ops.ref_merge([v_f, v_t])
- sess.run(result)
+ self.evaluate(result)
+
+ def testUInt64SwitchMerge(self):
+ with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
+ constant_uint64 = constant_op.constant(np.array([42]), dtypes.uint64)
+ cond = constant_op.constant(True, dtypes.bool)
+ v_f, v_t = control_flow_ops.switch(constant_uint64, cond)
+ result = control_flow_ops.merge([v_f, v_t])
+ self.evaluate(result)
def testQIntArgAndRet(self):
@@ -3355,7 +3377,7 @@
with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
qint = constant_op.constant(np.array([42]), dtypes.qint8)
result = func(qint)
- sess.run(result)
+ self.evaluate(result)
class ControlFlowContextCheckTest(test.TestCase):
@@ -3679,11 +3701,11 @@
for _ in xrange(3):
# exclude warm up time
- sess.run(r)
+ self.evaluate(r)
start_time = time.time()
for _ in xrange(num_iters):
- sess.run(r)
+ self.evaluate(r)
return (time.time() - start_time) / num_iters
def benchmarkWhileOpCrossDevicePlacement(self):
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 2d21f6f..2f6f3bb 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -1774,10 +1774,10 @@
conv = nn_ops.conv2d(t1, t2, strides=strides, padding=padding)
os.environ["TF_USE_DEEP_CONV2D"] = "0"
- values_expect = sess.run([conv])
+ values_expect = self.evaluate([conv])
os.environ["TF_USE_DEEP_CONV2D"] = "1"
- values_test = sess.run([conv])
+ values_test = self.evaluate([conv])
self.assertAllClose(values_expect, values_test, rtol=1e-5, atol=1e-5)
diff --git a/tensorflow/python/kernel_tests/ctc_loss_op_test.py b/tensorflow/python/kernel_tests/ctc_loss_op_test.py
index cfc7cb9..36cae28 100644
--- a/tensorflow/python/kernel_tests/ctc_loss_op_test.py
+++ b/tensorflow/python/kernel_tests/ctc_loss_op_test.py
@@ -23,9 +23,15 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import ctc_ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
@@ -52,6 +58,24 @@
return sparse_tensor.SparseTensor(x_ix, x_val, x_shape)
+def _ctc_loss_v2(labels, inputs, sequence_length,
+ preprocess_collapse_repeated=False,
+ ctc_merge_repeated=True,
+ ignore_longer_outputs_than_inputs=False,
+ time_major=True):
+ """Call ctc_loss_v2 with v1 args."""
+ assert not preprocess_collapse_repeated
+ assert ctc_merge_repeated
+ assert not ignore_longer_outputs_than_inputs
+ return ctc_ops.ctc_loss_v2(
+ labels=labels,
+ logits=inputs,
+ logit_length=sequence_length,
+ label_length=None,
+ blank_index=-1,
+ logits_time_major=time_major)
+
+
class CTCLossTest(test.TestCase):
def _testCTCLoss(self,
@@ -66,7 +90,7 @@
inputs_t = constant_op.constant(inputs)
with self.cached_session(use_gpu=False) as sess:
- loss = ctc_ops.ctc_loss(
+ loss = _ctc_loss_v2(
inputs=inputs_t, labels=labels, sequence_length=seq_lens)
grad = gradients_impl.gradients(loss, [inputs_t])[0]
@@ -74,12 +98,12 @@
self.assertShapeEqual(grad_truth, grad)
if expected_err_re is None:
- (tf_loss, tf_grad) = sess.run([loss, grad])
+ (tf_loss, tf_grad) = self.evaluate([loss, grad])
self.assertAllClose(tf_loss, loss_truth, atol=1e-6)
self.assertAllClose(tf_grad, grad_truth, atol=1e-6)
else:
with self.assertRaisesOpError(expected_err_re):
- sess.run([loss, grad])
+ self.evaluate([loss, grad])
def testBasic(self):
"""Test two batch entries."""
@@ -234,15 +258,15 @@
inputs_t_transposed = constant_op.constant(inputs.transpose(1, 0, 2))
with self.session(use_gpu=False) as sess:
- loss = ctc_ops.ctc_loss(
+ loss = _ctc_loss_v2(
inputs=inputs_t, labels=labels, sequence_length=seq_lens)
- loss_transposed = ctc_ops.ctc_loss(
+ loss_transposed = _ctc_loss_v2(
inputs=inputs_t_transposed,
labels=labels,
sequence_length=seq_lens,
time_major=False)
- (tf_loss, tf_loss_transposed) = sess.run([loss, loss_transposed])
+ (tf_loss, tf_loss_transposed) = self.evaluate([loss, loss_transposed])
self.assertAllEqual(tf_loss, tf_loss_transposed)
def testInvalidSecondGradient(self):
@@ -253,7 +277,7 @@
v = [1.0]
with self.session(use_gpu=False):
- loss = ctc_ops.ctc_loss(
+ loss = _ctc_loss_v2(
inputs=inputs_t, labels=labels, sequence_length=seq_lens)
# Taking ths second gradient should fail, since it is not
# yet supported.
@@ -272,7 +296,528 @@
with self.session(use_gpu=False) as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"batch_size must not be 0"):
- sess.run(ctc_ops.ctc_loss(labels, inputs, sequence_lengths))
+ sess.run(_ctc_loss_v2(labels, inputs, sequence_lengths))
+
+
+class CTCLossTestV2(test.TestCase):
+
+ def testCtcLossV2(self):
+ random_seed.set_random_seed(5)
+
+ batch_size = 8
+ num_labels = 6
+ max_label_length = 5
+ num_frames = 12
+
+ labels = random_ops.random_uniform(
+ [batch_size, max_label_length], minval=1, maxval=num_labels,
+ dtype=dtypes.int64)
+ logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
+
+ label_length = random_ops.random_uniform(
+ [batch_size], minval=2, maxval=max_label_length, dtype=dtypes.int64)
+ label_mask = array_ops.sequence_mask(
+ label_length, maxlen=max_label_length, dtype=label_length.dtype)
+ labels *= label_mask
+ logit_length = [num_frames] * batch_size
+
+ ref_loss = ctc_ops.ctc_loss_v2(
+ labels=labels,
+ logits=logits,
+ label_length=label_length,
+ logit_length=logit_length)
+ ref_grad = gradients_impl.gradients(ref_loss, [logits])
+
+ sparse_labels = ctc_ops.dense_labels_to_sparse(labels, label_length)
+
+ def assert_same_loss_and_grads(loss):
+ with self.cached_session() as sess:
+ self.assertAllClose(*self.evaluate([loss, ref_loss]))
+ grad = gradients_impl.gradients(loss, [logits])
+ self.assertAllClose(
+ *self.evaluate([grad, ref_grad]), rtol=2e-06, atol=2e-06)
+
+ assert_same_loss_and_grads(
+ ctc_ops.ctc_loss_v2(
+ labels=sparse_labels,
+ logits=logits,
+ label_length=label_length,
+ logit_length=logit_length,
+ blank_index=0))
+
+ def testCtcLossDenseIsSameAsCtcLoss(self):
+ with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
+ random_seed.set_random_seed(5)
+
+ batch_size = 8
+ num_labels = 6
+ label_length = 5
+ num_frames = 12
+ logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
+ labels = random_ops.random_uniform(
+ [batch_size, label_length], minval=1, maxval=num_labels,
+ dtype=dtypes.int64)
+
+ label_lengths = random_ops.random_uniform(
+ [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
+ label_mask = array_ops.sequence_mask(
+ label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
+ labels *= label_mask
+
+ logit_lengths = [num_frames] * batch_size
+
+ ctc_loss = ctc_ops.ctc_loss_dense(
+ labels=labels,
+ logits=logits,
+ label_length=label_lengths,
+ logit_length=logit_lengths)
+ ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]
+
+ # Shift labels down by one (move blank from 0 to num_labels -1)
+ tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32) - 1
+ tf_nn_ctc_logits = array_ops.concat([
+ logits[:, :, 1:],
+ logits[:, :, 0:1],
+ ], axis=2)
+
+ tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(
+ tf_ctc_loss_labels, label_lengths)
+
+ tf_nn_ctc_loss = ctc_ops.ctc_loss(
+ labels=tf_ctc_loss_labels,
+ inputs=tf_nn_ctc_logits,
+ sequence_length=logit_lengths,
+ time_major=True)
+ tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]
+
+ with self.cached_session() as sess:
+ for _ in range(32):
+ self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
+ self.assertAllClose(
+ *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
+ rtol=2e-06,
+ atol=2e-06)
+
+ def testCtcLossDenseUniqueFastPathIsSameAsCtcLoss(self):
+ random_seed.set_random_seed(5)
+
+ batch_size = 8
+ num_labels = 6
+ label_length = 5
+ num_frames = 12
+ logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
+ labels = random_ops.random_uniform(
+ [batch_size, label_length], minval=1, maxval=num_labels,
+ dtype=dtypes.int64)
+
+ label_lengths = random_ops.random_uniform(
+ [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
+ label_mask = array_ops.sequence_mask(
+ label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
+ labels *= label_mask
+
+ logit_lengths = [num_frames] * batch_size
+
+ ctc_loss = ctc_ops.ctc_loss_dense(
+ labels=labels,
+ logits=logits,
+ label_length=label_lengths,
+ logit_length=logit_lengths,
+ unique=ctc_ops.ctc_unique_labels(labels))
+ ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]
+
+ # Shift labels down by one (move blank from 0 to num_labels -1)
+ tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32) - 1
+ tf_nn_ctc_logits = array_ops.concat([
+ logits[:, :, 1:],
+ logits[:, :, 0:1],
+ ], axis=2)
+
+ tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(
+ tf_ctc_loss_labels, label_lengths)
+
+ tf_nn_ctc_loss = ctc_ops.ctc_loss(
+ labels=tf_ctc_loss_labels,
+ inputs=tf_nn_ctc_logits,
+ sequence_length=logit_lengths,
+ time_major=True)
+ tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]
+
+ with self.cached_session() as sess:
+ for _ in range(32):
+ self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
+ self.assertAllClose(
+ *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
+ rtol=2e-06,
+ atol=2e-06)
+
+ def testCtcLossDenseWithBlankIndexIsSameAsCtcLoss(self):
+ random_seed.set_random_seed(5)
+
+ batch_size = 8
+ num_labels = 6
+ label_length = 5
+ num_frames = 12
+ logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
+ labels = random_ops.random_uniform(
+ [batch_size, label_length], minval=0, maxval=num_labels-1,
+ dtype=dtypes.int64)
+
+ label_lengths = random_ops.random_uniform(
+ [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
+ label_mask = array_ops.sequence_mask(
+ label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
+ labels *= label_mask
+
+ logit_lengths = [num_frames] * batch_size
+
+ tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32)
+ tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(
+ tf_ctc_loss_labels, label_lengths)
+
+ tf_nn_ctc_loss = ctc_ops.ctc_loss(
+ labels=tf_ctc_loss_labels,
+ inputs=logits,
+ sequence_length=logit_lengths,
+ time_major=True)
+ tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]
+
+ # Shift the blank logits/labels to be somewhere in the middle.
+ blank_index = 2
+ shifted_logits = array_ops.concat([
+ logits[:, :, :blank_index],
+ logits[:, :, -1:],
+ logits[:, :, blank_index:-1],
+ ], axis=2)
+ shifted_labels = array_ops.where(labels < blank_index, labels, labels + 1)
+
+ ctc_loss = ctc_ops.ctc_loss_dense(
+ labels=shifted_labels,
+ logits=shifted_logits,
+ label_length=label_lengths,
+ logit_length=logit_lengths,
+ blank_index=blank_index)
+ ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]
+
+ with self.cached_session() as sess:
+ for _ in range(32):
+ self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
+ self.assertAllClose(
+ *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
+ rtol=2e-06,
+ atol=2e-06)
+
+ def testCtcLossDenseWithNegativeBlankIndexIsSameAsCtcLoss(self):
+ with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
+ random_seed.set_random_seed(5)
+
+ batch_size = 8
+ num_labels = 6
+ label_length = 5
+ num_frames = 12
+ logits = random_ops.random_uniform([num_frames, batch_size, num_labels])
+ labels = random_ops.random_uniform(
+ [batch_size, label_length], minval=0, maxval=num_labels-1,
+ dtype=dtypes.int64)
+
+ label_lengths = random_ops.random_uniform(
+ [batch_size], minval=2, maxval=label_length, dtype=dtypes.int64)
+ label_mask = array_ops.sequence_mask(
+ label_lengths, maxlen=label_length, dtype=label_lengths.dtype)
+ labels *= label_mask
+
+ logit_lengths = [num_frames] * batch_size
+
+ ctc_loss = ctc_ops.ctc_loss_dense(
+ labels=labels,
+ logits=logits,
+ label_length=label_lengths,
+ logit_length=logit_lengths,
+ blank_index=-1)
+ ctc_loss_grads = gradients_impl.gradients(ctc_loss, [logits])[0]
+
+ tf_ctc_loss_labels = math_ops.cast(labels, dtypes.int32)
+ tf_ctc_loss_labels = ctc_ops.dense_labels_to_sparse(
+ tf_ctc_loss_labels, label_lengths)
+
+ tf_nn_ctc_loss = ctc_ops.ctc_loss(
+ labels=tf_ctc_loss_labels,
+ inputs=logits,
+ sequence_length=logit_lengths,
+ time_major=True)
+ tf_nn_ctc_grads = gradients_impl.gradients(tf_nn_ctc_loss, [logits])[0]
+
+ with self.cached_session() as sess:
+ for _ in range(32):
+ self.assertAllClose(*self.evaluate([ctc_loss, tf_nn_ctc_loss]))
+ self.assertAllClose(
+ *self.evaluate([ctc_loss_grads, tf_nn_ctc_grads]),
+ rtol=2e-06,
+ atol=2e-06)
+
+ def testCollapseRepeated(self):
+ collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
+ labels=[[1, 3, 3, 3, 0],
+ [1, 4, 4, 4, 0],
+ [4, 2, 2, 9, 4]],
+ seq_length=[4, 5, 5])
+ self.assertAllEqual(new_seq_lengths, [2, 3, 4])
+ self.assertAllEqual(
+ collapsed,
+ [[1, 3, 0, 0],
+ [1, 4, 0, 0],
+ [4, 2, 9, 4]])
+
+ def testCollapseRepeatedPreservesDtypes(self):
+ collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
+ labels=constant_op.constant(
+ [[1, 3, 3, 3, 0],
+ [1, 4, 4, 4, 0],
+ [4, 2, 2, 9, 4]],
+ dtype=dtypes.int64),
+ seq_length=constant_op.constant([4, 5, 5], dtype=dtypes.int64))
+ self.assertEqual(new_seq_lengths.dtype, dtypes.int64)
+ self.assertEqual(collapsed.dtype, dtypes.int64)
+ self.assertAllEqual(new_seq_lengths, [2, 3, 4])
+ self.assertAllEqual(
+ collapsed,
+ [[1, 3, 0, 0],
+ [1, 4, 0, 0],
+ [4, 2, 9, 4]])
+
+ def testCollapseRepeatedExtraPadding(self):
+ collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
+ labels=[[1, 3, 3, 3, 0, 0, 0],
+ [1, 4, 4, 4, 0, 1, 2],
+ [4, 2, 2, 9, 4, 0, 0]],
+ seq_length=[4, 5, 5])
+ self.assertAllEqual(new_seq_lengths, [2, 3, 4])
+ self.assertAllEqual(
+ collapsed,
+ [[1, 3, 0, 0],
+ [1, 4, 0, 0],
+ [4, 2, 9, 4]])
+
+ def testCollapseRepeatedFrontRepeats(self):
+ collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
+ labels=[[1, 1, 1, 2, 2],
+ [1, 1, 1, 2, 2],
+ [1, 1, 1, 2, 2]],
+ seq_length=[5, 4, 3])
+ self.assertAllEqual(new_seq_lengths, [2, 2, 1])
+ self.assertAllEqual(
+ collapsed,
+ [[1, 2],
+ [1, 2],
+ [1, 0]])
+
+ def testCollapseRepeatedAllLabelsTheSame(self):
+ collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
+ labels=[[1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1]],
+ seq_length=[4, 5, 1])
+ self.assertAllEqual(new_seq_lengths, [1, 1, 1])
+ self.assertAllEqual(
+ collapsed,
+ [[1],
+ [1],
+ [1]])
+
+ def testDenseSequencesToSparse(self):
+ labels = [[1, 3, 3, 3, 0],
+ [1, 4, 4, 4, 0],
+ [4, 2, 2, 9, 4]]
+ length = [4, 5, 5]
+ sparse = ctc_ops.dense_labels_to_sparse(labels, length)
+ new_dense = sparse_ops.sparse_tensor_to_dense(sparse)
+
+ self.assertAllEqual(labels, new_dense)
+
+ padded_labels = [[1, 3, 3, 3, 0, 0, 0, 0],
+ [1, 4, 4, 4, 0, 0, 0, 0],
+ [4, 2, 2, 9, 4, 0, 0, 0]]
+ length = [4, 5, 5]
+ sparse = ctc_ops.dense_labels_to_sparse(padded_labels, length)
+ padded_dense = sparse_ops.sparse_tensor_to_dense(sparse)
+
+ self.assertAllEqual(padded_dense, new_dense)
+
+ def testUnique(self):
+ labels = [
+ [3, 4, 4, 3],
+ [1, 1, 1, 0],
+ ]
+ unique, idx = ctc_ops.ctc_unique_labels(labels)
+ self.assertAllEqual([
+ [3, 4, 0, 0],
+ [1, 0, 0, 0],
+ ], unique)
+ self.assertAllEqual([
+ [0, 1, 1, 0],
+ [0, 0, 0, 1],
+ ], idx)
+
+ def testSumStates(self):
+ idx = [
+ [0, 1, 0, 1],
+ [0, 0, 0, 1],
+ ]
+ states = math_ops.log([
+ [[1.0, 2.0, 3.0, 4.0],
+ [5.0, 6.0, 7.0, 8.0]],
+ [[0.1, 0.2, 0.3, 0.4],
+ [0.5, 0.6, 0.7, 0.8]],
+ ])
+ sum_of_states = math_ops.exp(ctc_ops._sum_states(idx, states))
+ self.assertAllClose([
+ [[4.0, 6.0, 0.0, 0.0],
+ [18.0, 8.0, 0.0, 0.0]],
+ [[0.4, 0.6, 0.0, 0.0],
+ [1.8, 0.8, 0.0, 0.0]]
+ ], sum_of_states)
+
+ def testStateToOlabel(self):
+ labels = [
+ [3, 4, 3, 4],
+ [1, 1, 1, 0],
+ ]
+ num_labels = 8
+
+ # 3 frames, 2 batch, 10 states (5 label, 5 blank).
+ states = [
+ [[0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20],
+ [0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30]],
+ [[1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
+ [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0]],
+ [[11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0],
+ [21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0]],
+ ]
+ labels = ops.convert_to_tensor(labels)
+ states = math_ops.log(states)
+ olabel = ctc_ops._state_to_olabel(labels, num_labels, states)
+ olabel = math_ops.exp(olabel)
+ blank = olabel[:, :, 0]
+ self.assertAllClose(blank, [
+ [0.16 + 0.17 + 0.18 + 0.19 + 0.20,
+ 0.26 + 0.27 + 0.28 + 0.29 + 0.30],
+ [1.6 + 1.7 + 1.8 + 1.9 + 2.0,
+ 2.6 + 2.7 + 2.8 + 2.9 + 3.0],
+ [16.0 + 17.0 + 18.0 + 19.0 + 20.0,
+ 26.0 + 27.0 + 28.0 + 29.0 + 30.0]
+ ])
+ self.assertAllClose(olabel[:, :, 1:], [
+ [[0.0, 0.0, 0.12 + 0.14, 0.13 + 0.15, 0.0, 0.0, 0.0],
+ [0.22 + 0.23 + 0.24, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 1.2 + 1.4, 1.3 + 1.5, 0.0, 0.0, 0.0],
+ [2.2 + 2.3 + 2.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 12.0 + 14.0, 13.0 + 15.0, 0.0, 0.0, 0.0],
+ [22.0 + 23.0 + 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+ ])
+
+ def testStateToOlabelUnique(self):
+ labels = [
+ [3, 4, 3, 4],
+ [1, 1, 1, 0],
+ ]
+ num_labels = 8
+
+ # 3 frames, 2 batch, 10 states (5 label, 5 blank).
+ states = [
+ [[0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20],
+ [0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30]],
+ [[1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0],
+ [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3.0]],
+ [[11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0],
+ [21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0]],
+ ]
+ labels = ops.convert_to_tensor(labels)
+ states = math_ops.log(states)
+ olabel = ctc_ops._state_to_olabel_unique(
+ labels, num_labels, states, ctc_ops.ctc_unique_labels(labels))
+ olabel = math_ops.exp(olabel)
+ blank = olabel[:, :, 0]
+ self.assertAllClose(blank, [
+ [0.16 + 0.17 + 0.18 + 0.19 + 0.20,
+ 0.26 + 0.27 + 0.28 + 0.29 + 0.30],
+ [1.6 + 1.7 + 1.8 + 1.9 + 2.0,
+ 2.6 + 2.7 + 2.8 + 2.9 + 3.0],
+ [16.0 + 17.0 + 18.0 + 19.0 + 20.0,
+ 26.0 + 27.0 + 28.0 + 29.0 + 30.0]])
+ self.assertAllClose(olabel[:, :, 1:], [
+ [[0.0, 0.0, 0.12 + 0.14, 0.13 + 0.15, 0.0, 0.0, 0.0],
+ [0.22 + 0.23 + 0.24, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 1.2 + 1.4, 1.3 + 1.5, 0.0, 0.0, 0.0],
+ [2.2 + 2.3 + 2.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 12.0 + 14.0, 13.0 + 15.0, 0.0, 0.0, 0.0],
+ [22.0 + 23.0 + 24.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
+ ])
+
+ def testScan(self):
+ with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
+ out = ctc_ops._scan(
+ lambda accum, elem: accum + elem,
+ constant_op.constant([1.0, 2.0, 3.0]), 23.0)
+ self.assertAllEqual([24.0, 26.0, 29.0], out)
+
+ out = ctc_ops._scan(
+ lambda a, e: a + e,
+ constant_op.constant([1.0, 2.0, 3.0]), 23.0,
+ inclusive=True)
+ self.assertAllEqual([23.0, 24.0, 26.0, 29.0], out)
+
+ out = ctc_ops._scan(
+ lambda a, e: a + e,
+ constant_op.constant([1.0, 2.0, 3.0]), 23.0,
+ reverse=True)
+ self.assertAllEqual([29.0, 28.0, 26.0], out)
+
+ out = ctc_ops._scan(
+ lambda a, e: a + e,
+ constant_op.constant([1.0, 2.0, 3.0]), 23.0,
+ reverse=True,
+ inclusive=True)
+ self.assertAllEqual([29.0, 28.0, 26.0, 23.0], out)
+
+ out = ctc_ops._scan(
+ lambda a, e: a + e,
+ constant_op.constant([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]),
+ constant_op.constant([23.0, 24.0]))
+ self.assertAllEqual([[23.0, 25.0], [25.0, 28.0], [29.0, 33.0]], out)
+
+ def testScanCapturesVariables(self):
+ with self.cached_session() as sess:
+ x = random_ops.random_uniform([])
+ fn = lambda accum, elem: accum + x * elem
+ out = ctc_ops._scan(fn, constant_op.constant([0.0, 1.0, 2.0]), 23.0)
+ self.assertAllEqual(*sess.run([
+ [23.0 + x * 0.0, 23.0 + x * 1.0, 23.0 + x * 3.0], out
+ ]))
+
+ def testScanMultipleAccumulators(self):
+ with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
+ def fn(accum, elem):
+ accum_a, accum_b = accum
+ return accum_a + elem, accum_b * elem
+ out = ctc_ops._scan(
+ fn, constant_op.constant([1.0, 2.0, 3.0]),
+ (23.0, constant_op.constant([1.0, 2.0])))
+ a, b = out
+ self.assertAllEqual([24.0, 26.0, 29.0], a)
+ self.assertAllEqual([[1.0, 2.0], [2.0, 4.0], [6.0, 12.0]], b)
+
+ def testScanMultipleElements(self):
+ with ops.device("/GPU:0" if test.is_gpu_available() else "/CPU:0"):
+ def fn(accum, elem):
+ elem_a, elem_b = elem
+ return accum + (elem_a * elem_b)
+ elems_a = constant_op.constant([1.0, 2.0, 3.0])
+ elems_b = constant_op.constant([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
+ out = ctc_ops._scan(
+ fn, (elems_a, elems_b),
+ initial=constant_op.constant([0.0, 0.0]))
+ self.assertAllEqual(
+ [[1.0, 2.0], [5.0, 8.0], [14.0, 20.0]], out)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
index df166b6..272c2b1 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_binary_test.py
@@ -77,23 +77,23 @@
def _compareCpu(self, x, y, np_func, tf_func, also_compare_variables=False):
np_ans = np_func(x, y)
- with self.test_session(use_gpu=False):
+ with test_util.force_cpu():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = tf_func(inx, iny)
tf_cpu = self.evaluate(out)
# Test that the op takes precedence over numpy operators.
- np_left = tf_func(x, iny).eval()
- np_right = tf_func(inx, y).eval()
+ np_left = self.evaluate(tf_func(x, iny))
+ np_right = self.evaluate(tf_func(inx, y))
if also_compare_variables:
var_x = variables.Variable(x)
var_y = variables.Variable(y)
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
print(type(x), type(y), type(var_x), type(var_y))
print(type(tf_func(x, var_y)), type(tf_func(var_x, y)))
- np_var_left = tf_func(x, var_y).eval()
- np_var_right = tf_func(var_x, y).eval()
+ np_var_left = self.evaluate(tf_func(x, var_y))
+ np_var_right = self.evaluate(tf_func(var_x, y))
if np_ans.dtype != np.object:
self.assertAllClose(np_ans, tf_cpu)
@@ -174,7 +174,7 @@
def _compareGpu(self, x, y, np_func, tf_func):
np_ans = np_func(x, y)
- with self.test_session(force_gpu=test_util.is_gpu_available()):
+ with test_util.use_gpu():
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = tf_func(inx, iny)
@@ -252,10 +252,12 @@
y = np.array([1, 2]).reshape(2, 1).astype(np.int32)
var_x = variables.Variable(x)
var_y = variables.Variable(y)
+
with self.cached_session() as sess:
- sess.run([var_x.initializer, var_y.initializer])
- left_result = (var_x * y).eval()
- right_result = (x * var_y).eval()
+ self.evaluate([var_x.initializer, var_y.initializer])
+ left_result = self.evaluate(var_x * y)
+ right_result = self.evaluate(x * var_y)
+
np_result = x * y
self.assertAllEqual(np_result, left_result)
self.assertAllEqual(np_result, right_result)
@@ -382,10 +384,10 @@
def testStringComparison(self):
x = np.array([["abc", "bh"], ["c", ""]])
y = np.array([["abc", "bh"], ["def", "hi"]])
- with self.test_session(use_gpu=False) as sess:
+ with test_util.force_cpu():
cmp_eq = math_ops.equal(x, y)
cmp_not_eq = math_ops.not_equal(x, y)
- values = sess.run([cmp_eq, cmp_not_eq])
+ values = self.evaluate([cmp_eq, cmp_not_eq])
self.assertAllEqual([[True, True], [False, False]], values[0])
self.assertAllEqual([[False, False], [True, True]], values[1])
@@ -716,35 +718,35 @@
def testPowNegativeExponent(self):
for dtype in [np.int32, np.int64]:
- with self.test_session(use_gpu=False) as sess:
+ with test_util.force_cpu():
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
"Integers to negative integer powers are not allowed"):
x = np.array([5, 2]).astype(dtype)
y = np.array([-2, 3]).astype(dtype)
- sess.run(math_ops.pow(x, y))
+ self.evaluate(math_ops.pow(x, y))
- with self.test_session(use_gpu=False) as sess:
+ with test_util.force_cpu():
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
"Integers to negative integer powers are not allowed"):
x = np.array([5, 2]).astype(dtype)
y = np.array([2, -3]).astype(dtype)
- sess.run(math_ops.pow(x, y))
+ self.evaluate(math_ops.pow(x, y))
- with self.test_session(use_gpu=False) as sess:
+ with test_util.force_cpu():
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
"Integers to negative integer powers are not allowed"):
x = np.array([5, 2]).astype(dtype)
y = -3
- sess.run(math_ops.pow(x, y))
+ self.evaluate(math_ops.pow(x, y))
class ComparisonOpTest(test.TestCase):
def _compareScalar(self, func, x, y, dtype):
- with self.test_session(force_gpu=test_util.is_gpu_available()):
+ with test_util.use_gpu():
out = func(
ops.convert_to_tensor(np.array([x]).astype(dtype)),
ops.convert_to_tensor(np.array([y]).astype(dtype)))
@@ -777,7 +779,7 @@
def _compare(self, x, y, np_func, tf_func):
np_ans = np_func(x, y)
- with self.test_session(force_gpu=test_util.is_gpu_available()):
+ with test_util.use_gpu():
out = tf_func(ops.convert_to_tensor(x), ops.convert_to_tensor(y))
tf_ans = self.evaluate(out)
self.assertAllEqual(np_ans, tf_ans)
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 87248bf..7e14f95 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -84,7 +84,7 @@
class ComparisonOpTest(test.TestCase):
def _compareScalar(self, func, x, y, dtype):
- with self.test_session(force_gpu=test_util.is_gpu_available()):
+ with test_util.use_gpu():
out = func(
ops.convert_to_tensor(np.array([x]).astype(dtype)),
ops.convert_to_tensor(np.array([y]).astype(dtype)))
@@ -117,7 +117,7 @@
def _compare(self, x, y, np_func, tf_func):
np_ans = np_func(x, y)
- with self.test_session(force_gpu=test_util.is_gpu_available()):
+ with test_util.use_gpu():
out = tf_func(ops.convert_to_tensor(x), ops.convert_to_tensor(y))
tf_ans = self.evaluate(out)
self.assertAllEqual(np_ans, tf_ans)
@@ -218,8 +218,7 @@
def _compareBinary(self, x, y, np_func, tf_func, use_gpu=False):
np_ans = np_func(x, y)
- with self.test_session(use_gpu=use_gpu,
- force_gpu=use_gpu and test_util.is_gpu_available()):
+ with test_util.device(use_gpu=use_gpu):
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
out = tf_func(inx, iny)
@@ -230,8 +229,7 @@
def _not(self, x, use_gpu=False):
np_ans = np.logical_not(x)
- with self.test_session(use_gpu=use_gpu,
- force_gpu=use_gpu and test_util.is_gpu_available()):
+ with test_util.device(use_gpu=use_gpu):
out = math_ops.logical_not(ops.convert_to_tensor(x))
tf_val = self.evaluate(out)
self.assertEqual(out.dtype, dtypes_lib.bool)
@@ -316,8 +314,7 @@
def _compare(self, c, x, y, use_gpu):
np_ans = np.where(c, x, y)
- with self.test_session(use_gpu=use_gpu,
- force_gpu=use_gpu and test_util.is_gpu_available()):
+ with test_util.device(use_gpu=use_gpu):
out = array_ops.where(c, x, y)
tf_ans = self.evaluate(out)
self.assertAllEqual(np_ans, tf_ans)
@@ -460,8 +457,7 @@
np_ans = np.dstack(
[x_i if c_i else y_i for c_i, x_i, y_i in zip(c, x, y)]).transpose(
[2, 0, 1])
- with self.test_session(use_gpu=use_gpu,
- force_gpu=use_gpu and test_util.is_gpu_available()):
+ with test_util.device(use_gpu=use_gpu):
out = array_ops.where(c, x, y)
tf_ans = self.evaluate(out)
self.assertAllEqual(np_ans, tf_ans)
@@ -566,13 +562,11 @@
def _compare(self, x, y, use_gpu):
np_min, np_max = np.minimum(x, y), np.maximum(x, y)
- with self.test_session(
- use_gpu=use_gpu,
- force_gpu=use_gpu and test_util.is_gpu_available()) as sess:
+ with test_util.device(use_gpu=use_gpu):
inx = ops.convert_to_tensor(x)
iny = ops.convert_to_tensor(y)
omin, omax = math_ops.minimum(inx, iny), math_ops.maximum(inx, iny)
- tf_min, tf_max = sess.run([omin, omax])
+ tf_min, tf_max = self.evaluate([omin, omax])
self.assertAllEqual(np_min, tf_min)
self.assertAllEqual(np_max, tf_max)
@@ -641,13 +635,13 @@
class MathOpsOverloadTest(test.TestCase):
def _computeTensorAndLiteral(self, x, y, dtype, func):
- with self.test_session(use_gpu=False):
+ with test_util.force_cpu():
inx = ops.convert_to_tensor(x, dtype=dtype)
z = func(inx, y) # Should use __add__, __sub__, etc.
return self.evaluate(z)
def _computeLiteralAndTensor(self, x, y, dtype, func):
- with self.test_session(use_gpu=False):
+ with test_util.force_cpu():
iny = ops.convert_to_tensor(y, dtype=dtype)
z = func(x, iny) # Should use __radd__, __rsub__, etc.
return self.evaluate(z)
@@ -661,9 +655,9 @@
def _compareUnary(self, x, dtype, np_func, tf_func):
np_ans = np_func(x).astype(dtype.as_numpy_dtype)
- with self.test_session(use_gpu=False):
- self.assertAllClose(np_ans,
- tf_func(ops.convert_to_tensor(x, dtype=dtype)).eval())
+ with test_util.force_cpu():
+ self.assertAllClose(
+ np_ans, self.evaluate(tf_func(ops.convert_to_tensor(x, dtype=dtype))))
def testOverload(self):
dtypes = [
@@ -730,13 +724,11 @@
def _compare(self, x, use_gpu):
np_finite, np_inf, np_nan = np.isfinite(x), np.isinf(x), np.isnan(x)
- with self.test_session(
- use_gpu=use_gpu,
- force_gpu=use_gpu and test_util.is_gpu_available()) as sess:
+ with test_util.device(use_gpu=use_gpu):
inx = ops.convert_to_tensor(x)
ofinite, oinf, onan = math_ops.is_finite(inx), math_ops.is_inf(
inx), math_ops.is_nan(inx)
- tf_finite, tf_inf, tf_nan = sess.run([ofinite, oinf, onan])
+ tf_finite, tf_inf, tf_nan = self.evaluate([ofinite, oinf, onan])
self.assertAllEqual(np_inf, tf_inf)
self.assertAllEqual(np_nan, tf_nan)
self.assertAllEqual(np_finite, tf_finite)
@@ -773,7 +765,7 @@
x = np.full((size,), value, dtype=dtype)
np_y = np.sqrt(x)
np_nan = np.isnan(np_y)
- with self.test_session(force_gpu=test_util.is_gpu_available()):
+ with test_util.use_gpu():
tf_y = math_ops.sqrt(x)
tf_nan = math_ops.is_nan(tf_y)
if value < 0:
@@ -786,18 +778,20 @@
def _compare_values(self, x, y=None):
y = np.rint(x) if y is None else np.asarray(y)
- with self.cached_session() as sess:
- tf_rint = math_ops.rint(x)
- np_rint = self.evaluate(tf_rint)
+
+ tf_rint = math_ops.rint(x)
+ np_rint = self.evaluate(tf_rint)
+
self.assertAllEqual(y, np_rint)
self.assertShapeEqual(y, tf_rint)
def _compare(self, x):
np_floor, np_ceil = np.floor(x), np.ceil(x)
- with self.cached_session() as sess:
- inx = ops.convert_to_tensor(x)
- ofloor, oceil = math_ops.floor(inx), math_ops.ceil(inx)
- tf_floor, tf_ceil = sess.run([ofloor, oceil])
+
+ inx = ops.convert_to_tensor(x)
+ ofloor, oceil = math_ops.floor(inx), math_ops.ceil(inx)
+ tf_floor, tf_ceil = self.evaluate([ofloor, oceil])
+
self.assertAllEqual(np_floor, tf_floor)
self.assertAllEqual(np_ceil, tf_ceil)
self.assertShapeEqual(np_floor, ofloor)
@@ -828,12 +822,13 @@
def _compareMake(self, real, imag, use_gpu):
np_ans = real + (1j) * imag
- with self.test_session(use_gpu=use_gpu,
- force_gpu=use_gpu and test_util.is_gpu_available()):
+
+ with test_util.device(use_gpu=use_gpu):
real = ops.convert_to_tensor(real)
imag = ops.convert_to_tensor(imag)
tf_ans = math_ops.complex(real, imag)
out = self.evaluate(tf_ans)
+
self.assertAllEqual(np_ans, out)
self.assertShapeEqual(np_ans, tf_ans)
@@ -848,8 +843,8 @@
def _compareRealImag(self, cplx, use_gpu):
np_real, np_imag = np.real(cplx), np.imag(cplx)
np_zeros = np_real * 0
- with self.test_session(use_gpu=use_gpu,
- force_gpu=use_gpu and test_util.is_gpu_available()):
+
+ with test_util.device(use_gpu=use_gpu):
inx = ops.convert_to_tensor(cplx)
tf_real = math_ops.real(inx)
tf_imag = math_ops.imag(inx)
@@ -876,12 +871,12 @@
def _compareAngle(self, cplx, use_gpu):
np_angle = np.angle(cplx)
- with self.test_session(
- use_gpu=use_gpu,
- force_gpu=use_gpu and test_util.is_gpu_available()) as sess:
+
+ with test_util.device(use_gpu=use_gpu):
inx = ops.convert_to_tensor(cplx)
tf_angle = math_ops.angle(inx)
tf_angle_val = self.evaluate(tf_angle)
+
self.assertAllEqual(np_angle, tf_angle_val)
self.assertShapeEqual(np_angle, tf_angle)
@@ -912,8 +907,7 @@
def _compareConj(self, cplx, use_gpu):
np_ans = np.conj(cplx)
- with self.test_session(use_gpu=use_gpu,
- force_gpu=use_gpu and test_util.is_gpu_available()):
+ with test_util.device(use_gpu=use_gpu):
inx = ops.convert_to_tensor(cplx)
tf_conj = math_ops.conj(inx)
tf_ans = self.evaluate(tf_conj)
diff --git a/tensorflow/python/kernel_tests/cwise_ops_unary_test.py b/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
index 7096083..3e8294f 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_unary_test.py
@@ -76,7 +76,7 @@
if grad_atol is None:
grad_atol = _default_tolerance(x.dtype)
np_ans = np_func(x)
- with self.test_session(use_gpu=False):
+ with self.cached_session(use_gpu=False):
inx = ops.convert_to_tensor(x)
if x.dtype in (np.float32, np.float64,
dtypes_lib.bfloat16.as_numpy_dtype):
@@ -121,24 +121,22 @@
def _check(self, result_tensor, result_np, input_sp_t, tol):
self.assertTrue(isinstance(result_tensor, sparse_tensor.SparseTensor))
self.assertTrue(isinstance(input_sp_t, sparse_tensor.SparseTensor))
- self.assertAllEqual(input_sp_t.indices.eval(), result_tensor.indices.eval())
- self.assertAllEqual(input_sp_t.dense_shape.eval(),
- result_tensor.dense_shape.eval())
+ self.assertAllEqual(input_sp_t.indices, result_tensor.indices)
+ self.assertAllEqual(input_sp_t.dense_shape, result_tensor.dense_shape)
if tol is None:
- self.assertAllClose(result_np, result_tensor.values.eval())
+ self.assertAllClose(result_np, result_tensor.values)
else:
- self.assertAllClose(
- result_np, result_tensor.values.eval(), rtol=tol, atol=tol)
+ self.assertAllClose(result_np, result_tensor.values, rtol=tol, atol=tol)
def _compareSparseCpu(self, x, np_func, tf_func, tol):
x_sp, x_sp_vals = _sparsify(x)
res_np = np_func(x_sp_vals)
- with self.test_session(use_gpu=False):
+ with test_util.force_cpu():
self._check(tf_func(x_sp), res_np, x_sp, tol)
def _compareGpu(self, x, np_func, tf_func):
np_ans = np_func(x)
- with self.test_session(force_gpu=test_util.is_gpu_available()):
+ with test_util.use_gpu():
result = tf_func(ops.convert_to_tensor(x))
tf_gpu = self.evaluate(result)
if x.dtype == np.float16:
@@ -150,7 +148,7 @@
def _compareSparseGpu(self, x, np_func, tf_func, tol):
x_sp, x_sp_vals = _sparsify(x)
res_np = np_func(x_sp_vals)
- with self.test_session(force_gpu=test_util.is_gpu_available()):
+ with test_util.use_gpu():
self._check(tf_func(x_sp), res_np, x_sp, tol)
def _compareBoth(self, x, np_func, tf_func):
diff --git a/tensorflow/python/kernel_tests/decode_image_op_test.py b/tensorflow/python/kernel_tests/decode_image_op_test.py
index 7a8743e..267afde 100644
--- a/tensorflow/python/kernel_tests/decode_image_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_image_op_test.py
@@ -40,7 +40,7 @@
bmp0 = io_ops.read_file(path)
image0 = image_ops.decode_image(bmp0)
image1 = image_ops.decode_bmp(bmp0)
- bmp0, image0, image1 = sess.run([bmp0, image0, image1])
+ bmp0, image0, image1 = self.evaluate([bmp0, image0, image1])
self.assertEqual(len(bmp0), 4194)
self.assertAllEqual(image0, image1)
@@ -56,7 +56,7 @@
gif0 = io_ops.read_file(path)
image0 = image_ops.decode_image(gif0)
image1 = image_ops.decode_gif(gif0)
- gif0, image0, image1 = sess.run([gif0, image0, image1])
+ gif0, image0, image1 = self.evaluate([gif0, image0, image1])
self.assertEqual(image0.shape, shape)
self.assertAllEqual(image0, image1)
@@ -85,7 +85,7 @@
jpeg0 = io_ops.read_file(path)
image0 = image_ops.decode_image(jpeg0)
image1 = image_ops.decode_jpeg(jpeg0)
- jpeg0, image0, image1 = sess.run([jpeg0, image0, image1])
+ jpeg0, image0, image1 = self.evaluate([jpeg0, image0, image1])
self.assertEqual(len(jpeg0), 3771)
self.assertEqual(image0.shape, (256, 128, 3))
self.assertAllEqual(image0, image1)
@@ -104,7 +104,7 @@
png0 = io_ops.read_file(path)
image0 = image_ops.decode_image(png0, channels=channels)
image1 = image_ops.decode_png(png0, channels=channels)
- png0, image0, image1 = sess.run([png0, image0, image1])
+ png0, image0, image1 = self.evaluate([png0, image0, image1])
self.assertEqual(image0.shape, (26, 51, channels or channels_in))
self.assertAllEqual(image0, image1)
diff --git a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py
index 8c4ccbd..f8fc280 100644
--- a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py
@@ -105,11 +105,11 @@
for _ in xrange(3):
# Skip warm up time.
- sess.run(r)
+ self.evaluate(r)
start_time = time.time()
for _ in xrange(num_iters):
- sess.run(r)
+ self.evaluate(r)
end_time = time.time()
return end_time - start_time
diff --git a/tensorflow/python/kernel_tests/depthtospace_op_test.py b/tensorflow/python/kernel_tests/depthtospace_op_test.py
index c4bed11..19f1458 100644
--- a/tensorflow/python/kernel_tests/depthtospace_op_test.py
+++ b/tensorflow/python/kernel_tests/depthtospace_op_test.py
@@ -277,7 +277,7 @@
actual = array_ops.depth_to_space(t, block_size, data_format=data_format)
with self.session(use_gpu=use_gpu) as sess:
- actual_vals, expected_vals = sess.run([actual, expected])
+ actual_vals, expected_vals = self.evaluate([actual, expected])
self.assertTrue(np.array_equal(actual_vals, expected_vals))
def testAgainstTranspose(self):
diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py
index 78c1d74..d6ef9e7 100644
--- a/tensorflow/python/kernel_tests/determinant_op_test.py
+++ b/tensorflow/python/kernel_tests/determinant_op_test.py
@@ -23,6 +23,7 @@
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
@@ -62,7 +63,7 @@
atol=5e-5)
def _compareDeterminant(self, matrix_x):
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
self._compareDeterminantBase(matrix_x,
linalg_ops.matrix_determinant(matrix_x))
self._compareLogDeterminantBase(
@@ -155,7 +156,7 @@
matrix2 = random_ops.random_normal([5, 5], seed=42)
det1 = linalg_ops.matrix_determinant(matrix1)
det2 = linalg_ops.matrix_determinant(matrix2)
- det1_val, det2_val = sess.run([det1, det2])
+ det1_val, det2_val = self.evaluate([det1, det2])
self.assertEqual(det1_val, det2_val)
diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py
index 6b6de8b..0f800b9 100644
--- a/tensorflow/python/kernel_tests/distributions/special_math_test.py
+++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py
@@ -448,7 +448,7 @@
actual = sm.log_cdf_laplace(grid)
grad = gradients_impl.gradients(actual, grid)[0]
- actual_, grad_ = sess.run([actual, grad])
+ actual_, grad_ = self.evaluate([actual, grad])
# isfinite checks for NaN and Inf.
self.assertAllTrue(np.isfinite(actual_))
@@ -467,7 +467,7 @@
actual = sm.log_cdf_laplace(grid)
grad = gradients_impl.gradients(actual, grid)[0]
- actual_, grad_ = sess.run([actual, grad])
+ actual_, grad_ = self.evaluate([actual, grad])
# isfinite checks for NaN and Inf.
self.assertAllTrue(np.isfinite(actual_))
diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py
index f4e651b..d3fa513 100644
--- a/tensorflow/python/kernel_tests/distributions/util_test.py
+++ b/tensorflow/python/kernel_tests/distributions/util_test.py
@@ -805,7 +805,7 @@
w = constant_op.constant(w_)
actual, actual_sgn = du.reduce_weighted_logsumexp(
logx, w, axis=-1, return_sign=True)
- [actual_, actual_sgn_] = sess.run([actual, actual_sgn])
+ [actual_, actual_sgn_] = self.evaluate([actual, actual_sgn])
self.assertAllEqual(expected, actual_)
self.assertAllEqual([-1., -1, 1], actual_sgn_)
@@ -823,7 +823,7 @@
w = constant_op.constant(w_)
actual, actual_sgn = du.reduce_weighted_logsumexp(
logx, w, axis=-1, return_sign=True, keep_dims=True)
- [actual_, actual_sgn_] = sess.run([actual, actual_sgn])
+ [actual_, actual_sgn_] = self.evaluate([actual, actual_sgn])
self.assertAllEqual(expected, actual_)
self.assertAllEqual([[-1.], [-1], [1]], actual_sgn_)
diff --git a/tensorflow/python/kernel_tests/division_future_test.py b/tensorflow/python/kernel_tests/division_future_test.py
index e477bdc..85c8580 100644
--- a/tensorflow/python/kernel_tests/division_future_test.py
+++ b/tensorflow/python/kernel_tests/division_future_test.py
@@ -65,7 +65,7 @@
tf_floordiv = tf_x // tf_y
check(floordiv, tf_floordiv)
# Do only one sess.run for speed
- for f, (x, y) in zip(checks, sess.run(tensors)):
+ for f, (x, y) in zip(checks, self.evaluate(tensors)):
f(x, y)
diff --git a/tensorflow/python/kernel_tests/division_past_test.py b/tensorflow/python/kernel_tests/division_past_test.py
index 63951b5..38bb186 100644
--- a/tensorflow/python/kernel_tests/division_past_test.py
+++ b/tensorflow/python/kernel_tests/division_past_test.py
@@ -64,7 +64,7 @@
tf_floordiv = tf_x // tf_y
check(floordiv, tf_floordiv)
# Do only one sess.run for speed
- for f, (x, y) in zip(checks, sess.run(tensors)):
+ for f, (x, y) in zip(checks, self.evaluate(tensors)):
f(x, y)
diff --git a/tensorflow/python/kernel_tests/draw_bounding_box_op_test.py b/tensorflow/python/kernel_tests/draw_bounding_box_op_test.py
index c655876..6aa757e 100644
--- a/tensorflow/python/kernel_tests/draw_bounding_box_op_test.py
+++ b/tensorflow/python/kernel_tests/draw_bounding_box_op_test.py
@@ -87,7 +87,7 @@
image = array_ops.expand_dims(image, 0)
image = image_ops.draw_bounding_boxes(image, bboxes)
with self.cached_session(use_gpu=False) as sess:
- op_drawn_image = np.squeeze(sess.run(image), 0)
+ op_drawn_image = np.squeeze(self.evaluate(image), 0)
self.assertAllEqual(test_drawn_image, op_drawn_image)
def testDrawBoundingBoxRGBColorCycling(self):
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
index 80da39d..3622fde 100644
--- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
@@ -295,7 +295,7 @@
partitions = data_flow_ops.dynamic_partition(
data, indices, num_partitions=4)
with self.assertRaisesOpError(r"partitions\[2\] = 99 is not in \[0, 4\)"):
- sess.run(partitions)
+ self.evaluate(partitions)
def testScalarIndexOutOfRange(self):
with self.cached_session() as sess:
@@ -303,7 +303,7 @@
data = np.zeros(5)
partitions = data_flow_ops.dynamic_partition(data, bad, num_partitions=7)
with self.assertRaisesOpError(r"partitions = 17 is not in \[0, 7\)"):
- sess.run(partitions)
+ self.evaluate(partitions)
def testHigherRankIndexOutOfRange(self):
with self.cached_session() as sess:
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
index c0b0e3f..3d063c4 100644
--- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py
@@ -22,6 +22,7 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
@@ -36,7 +37,7 @@
self.stitch_op = stitch_op
def testScalar(self):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
indices = [constant_op.constant(0), constant_op.constant(1)]
data = [constant_op.constant(40), constant_op.constant(60)]
for step in -1, 1:
@@ -47,7 +48,7 @@
self.assertEqual([2], stitched_t.get_shape().as_list())
def testShapeInferenceForScalarWithNonConstantIndices(self):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
indices = [
array_ops.placeholder(dtype=dtypes.int32),
constant_op.constant(1)
@@ -61,7 +62,7 @@
self.assertEqual([None], stitched_t.get_shape().as_list())
def testSimpleOneDimensional(self):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
# Test various datatypes in the simple case to ensure that the op was
# registered under those types.
dtypes_to_test = [
@@ -84,7 +85,7 @@
self.assertEqual([8], stitched_t.get_shape().as_list())
def testOneListOneDimensional(self):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
indices = [constant_op.constant([1, 6, 2, 3, 5, 0, 4, 7])]
data = [constant_op.constant([10, 60, 20, 30, 50, 0, 40, 70])]
stitched_t = self.stitch_op(indices, data)
@@ -94,7 +95,7 @@
self.assertEqual([8], stitched_t.get_shape().as_list())
def testSimpleTwoDimensional(self):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
indices = [
constant_op.constant([0, 4, 7]),
constant_op.constant([1, 6]),
@@ -113,7 +114,7 @@
self.assertEqual([8, 2], stitched_t.get_shape().as_list())
def testZeroSizeTensor(self):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
indices = [
constant_op.constant([0, 4, 7]),
constant_op.constant([1, 6]),
@@ -222,7 +223,7 @@
DynamicStitchTestBase.__init__(self, data_flow_ops.parallel_dynamic_stitch)
def testScalar(self):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
indices = [constant_op.constant(0), constant_op.constant(1)]
data = [constant_op.constant(40.0), constant_op.constant(60.0)]
for step in -1, 1:
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 443f54a..39c0575 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -294,7 +294,7 @@
variables.global_variables_initializer().run()
params_values = [params[p_i.name] for p_i in p]
# Test that the PartitionedVariable components equal the list in p
- p_var_val = sess.run(list(p_variable))
+ p_var_val = self.evaluate(list(p_variable))
# Actual test
tf_result = embedding.eval(feed_dict=feed_dict)
np_result, _, _ = _EmbeddingResult(params, id_vals, num_shards, vocab_size)
@@ -316,7 +316,7 @@
variables.global_variables_initializer().run()
params_values = [params[p_i.name] for p_i in p]
# Test that the PartitionedVariable components equal the list in p
- p_var_val = sess.run(list(p_variable))
+ p_var_val = self.evaluate(list(p_variable))
# Actual test
print(ops.get_default_graph().as_graph_def())
tf_result = self.evaluate(embedding)
@@ -758,11 +758,13 @@
assert num_shards > 0
assert num_shards <= vocab_size
- embedding_weights = partitioned_variables.create_partitioned_variables(
+ initializer = init_ops.truncated_normal_initializer(
+ mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32)
+ embedding_weights = list(variable_scope.get_variable(
+ name="embedding_weights",
shape=[vocab_size, embed_dim],
- slicing=[num_shards, 1],
- initializer=init_ops.truncated_normal_initializer(
- mean=0.0, stddev=1.0 / math.sqrt(vocab_size), dtype=dtypes.float32))
+ partitioner=partitioned_variables.fixed_size_partitioner(num_shards),
+ initializer=initializer))
for w in embedding_weights:
w.initializer.run()
embedding_weights = [w.eval() for w in embedding_weights]
diff --git a/tensorflow/python/kernel_tests/extract_image_patches_op_test.py b/tensorflow/python/kernel_tests/extract_image_patches_op_test.py
index 4fe51e9..bb3c0ae 100644
--- a/tensorflow/python/kernel_tests/extract_image_patches_op_test.py
+++ b/tensorflow/python/kernel_tests/extract_image_patches_op_test.py
@@ -21,6 +21,7 @@
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -43,7 +44,7 @@
strides = [1] + strides + [1]
rates = [1] + rates + [1]
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
out_tensor = array_ops.extract_image_patches(
constant_op.constant(image),
ksizes=ksizes,
diff --git a/tensorflow/python/kernel_tests/extract_volume_patches_op_test.py b/tensorflow/python/kernel_tests/extract_volume_patches_op_test.py
index d99823d..88f7df8 100644
--- a/tensorflow/python/kernel_tests/extract_volume_patches_op_test.py
+++ b/tensorflow/python/kernel_tests/extract_volume_patches_op_test.py
@@ -21,6 +21,7 @@
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -45,7 +46,7 @@
ksizes = [1] + ksizes + [1]
strides = [1] + strides + [1]
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
out_tensor = array_ops.extract_volume_patches(
constant_op.constant(image),
ksizes=ksizes,
diff --git a/tensorflow/python/kernel_tests/fifo_queue_test.py b/tensorflow/python/kernel_tests/fifo_queue_test.py
index c184b93..9655351 100644
--- a/tensorflow/python/kernel_tests/fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/fifo_queue_test.py
@@ -191,7 +191,7 @@
results = []
def dequeue():
- results.append(sess.run(dequeued_t))
+ results.append(self.evaluate(dequeued_t))
threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
for thread in threads:
@@ -246,7 +246,7 @@
def dequeue():
for _ in xrange(len(elems)):
- results.append(sess.run(dequeued_t))
+ results.append(self.evaluate(dequeued_t))
enqueue_thread = self.checkedThread(target=enqueue)
dequeue_thread = self.checkedThread(target=dequeue)
@@ -552,7 +552,7 @@
dequeued_elems = []
def dequeue():
- dequeued_elems.extend(sess.run(dequeued_t))
+ dequeued_elems.extend(self.evaluate(dequeued_t))
threads = [self.checkedThread(target=dequeue) for _ in range(10)]
for thread in threads:
@@ -576,7 +576,7 @@
dequeued_elems = []
def dequeue():
- dequeued_elems.extend(sess.run(dequeued_t))
+ dequeued_elems.extend(self.evaluate(dequeued_t))
threads = [self.checkedThread(target=dequeue) for _ in range(10)]
for thread in threads:
@@ -704,7 +704,7 @@
self.evaluate(enqueue_op)
def dequeue():
- dequeued_elems.extend(sess.run(dequeued_t).tolist())
+ dequeued_elems.extend(self.evaluate(dequeued_t).tolist())
enqueue_thread = self.checkedThread(target=enqueue)
dequeue_thread = self.checkedThread(target=dequeue)
@@ -731,7 +731,7 @@
self.evaluate(enqueue_op)
def dequeue():
- dequeued_elems.extend(sess.run(dequeued_t).tolist())
+ dequeued_elems.extend(self.evaluate(dequeued_t).tolist())
enqueue_thread = self.checkedThread(target=enqueue)
dequeue_thread = self.checkedThread(target=dequeue)
@@ -801,7 +801,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -821,7 +821,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -846,7 +846,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -871,7 +871,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -918,7 +918,7 @@
def dequeue():
self.assertAllEqual(elems[0:3], self.evaluate(dequeued_t))
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
self.assertEqual(elems[3], self.evaluate(cleanup_dequeue_t))
def close():
@@ -955,7 +955,7 @@
def dequeue():
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run([dequeued_a_t, dequeued_b_t])
+ self.evaluate([dequeued_a_t, dequeued_b_t])
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -968,7 +968,7 @@
# Test that the elements in the partially-dequeued batch are
# restored in the correct order.
for elem_a, elem_b in zip(elems_a, elems_b):
- val_a, val_b = sess.run([cleanup_dequeue_a_t, cleanup_dequeue_b_t])
+ val_a, val_b = self.evaluate([cleanup_dequeue_a_t, cleanup_dequeue_b_t])
self.assertEqual(elem_a, val_a)
self.assertEqual(elem_b, val_b)
self.assertEqual(0, q.size().eval())
@@ -983,7 +983,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -1003,7 +1003,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -1321,7 +1321,7 @@
def blocking_enqueue():
enq_done.append(False)
# This will fill the queue and then block until enough dequeues happen.
- sess.run(enq)
+ self.evaluate(enq)
enq_done.append(True)
thread = self.checkedThread(target=blocking_enqueue)
@@ -1364,7 +1364,7 @@
def blocking_dequeue():
# Will only complete after 4 enqueues complete.
- results.extend(sess.run(deq))
+ results.extend(self.evaluate(deq))
thread = self.checkedThread(target=blocking_dequeue)
thread.start()
@@ -1373,7 +1373,7 @@
# TODO(mrry): Figure out how to do this without sleeping.
time.sleep(0.1)
self.assertEqual(len(results), 0)
- sess.run(enq)
+ self.evaluate(enq)
# Enough enqueued to unblock the dequeue
thread.join()
@@ -1508,9 +1508,9 @@
dequeue = q.dequeue()
dequeue_2 = q.dequeue_many(2)
self.evaluate(enqueue_op)
- sess.run(enqueue_op2)
- sess.run(enqueue_op3)
- sess.run(enqueue_op4)
+ self.evaluate(enqueue_op2)
+ self.evaluate(enqueue_op3)
+ self.evaluate(enqueue_op4)
f = sess.run(dequeue["f"])
self.assertEqual(10.0, f)
f = sess.run(dequeue_2["f"])
@@ -1566,9 +1566,9 @@
dequeue = q.dequeue()
dequeue_2 = q.dequeue_many(2)
self.evaluate(enqueue_op)
- sess.run(enqueue_op2)
- sess.run(enqueue_op3)
- sess.run(enqueue_op4)
+ self.evaluate(enqueue_op2)
+ self.evaluate(enqueue_op3)
+ self.evaluate(enqueue_op4)
i, f, s = sess.run([dequeue["i"], dequeue["f"], dequeue["s"]])
self.assertEqual(123, i)
self.assertEqual(10.0, f)
@@ -1597,7 +1597,7 @@
# until operation_timeout_in_ms.
with self.assertRaisesRegexp(errors_impl.DeadlineExceededError,
"Timed out waiting for notification"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
def testReusableAfterTimeout(self):
with self.cached_session() as sess:
diff --git a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
index cb7659a..272adec 100644
--- a/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
+++ b/tensorflow/python/kernel_tests/fractional_avg_pool_op_test.py
@@ -133,7 +133,7 @@
pseudo_random,
overlapping,
seed=self._SEED)
- actual, row_seq, col_seq = sess.run([p, r, c])
+ actual, row_seq, col_seq = self.evaluate([p, r, c])
expected = self._GetExpectedFractionalAvgPoolResult(input_tensor, row_seq,
col_seq, overlapping)
self.assertShapeEqual(expected, p)
@@ -164,7 +164,7 @@
pseudo_random,
overlapping,
seed=self._SEED)
- tensor_output, row_seq, col_seq = sess.run([p, r, c])
+ tensor_output, row_seq, col_seq = self.evaluate([p, r, c])
expected_result = self._GetExpectedFractionalAvgPoolResult(
rand_mat.astype(np.float32), row_seq, col_seq, overlapping)
print("row sequence:")
diff --git a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
index 0427e34..9b1e73b 100644
--- a/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
+++ b/tensorflow/python/kernel_tests/fractional_max_pool_op_test.py
@@ -133,7 +133,7 @@
pseudo_random,
overlapping,
seed=self._SEED)
- actual, row_seq, col_seq = sess.run([p, r, c])
+ actual, row_seq, col_seq = self.evaluate([p, r, c])
expected = self._GetExpectedFractionalMaxPoolResult(input_tensor, row_seq,
col_seq, overlapping)
self.assertShapeEqual(expected, p)
@@ -164,7 +164,7 @@
pseudo_random,
overlapping,
seed=self._SEED)
- tensor_output, row_seq, col_seq = sess.run([p, r, c])
+ tensor_output, row_seq, col_seq = self.evaluate([p, r, c])
expected_result = self._GetExpectedFractionalMaxPoolResult(rand_mat,
row_seq,
col_seq,
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 0af32b0..23b3c7e 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -458,7 +458,7 @@
grad = gradients_impl.gradients(ys=[loss], xs=[a, b])
with self.test_session(use_gpu=True) as sess:
variables.global_variables_initializer().run()
- sess.run(grad)
+ self.evaluate(grad)
@test_util.run_in_graph_and_eager_modes
def testFoldShape(self):
@@ -769,7 +769,7 @@
else:
fetch = "my_while:1"
with self.session(graph=g, use_gpu=use_gpu) as sess:
- return sess.run(fetch)
+ return self.evaluate(fetch)
self.assertAllEqual(Run(20., False), 210.)
self.assertAllEqual(Run(20., True), 210.)
@@ -1194,7 +1194,7 @@
log_device_placement=True,
device_count={"CPU": 2})) as sess:
self.evaluate(variables.global_variables_initializer())
- expected = sess.run(sum_gather())
+ expected = self.evaluate(sum_gather())
result = sess.run(
functional_ops.partitioned_call(
args=defined.captured_inputs, f=defined))
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index 074985d..87c7bbe 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -704,7 +704,7 @@
ratio = outputs_2norm / inputs_2norm
my_ops = variables.global_variables_initializer()
with self.session(use_gpu=True) as sess:
- sess.run(my_ops)
+ self.evaluate(my_ops)
# Check the shape of the outputs
t = self.evaluate(outputs)
self.assertAllEqual(t.shape, outputs_shape)
@@ -842,7 +842,7 @@
ratio = outputs_2norm / inputs_2norm
my_ops = variables.global_variables_initializer()
with self.session(use_gpu=True) as sess:
- sess.run(my_ops)
+ self.evaluate(my_ops)
# Check the shape of the outputs
t = self.evaluate(outputs)
self.assertAllEqual(t.shape, outputs_shape)
@@ -937,7 +937,7 @@
ratio = outputs_2norm / inputs_2norm
my_ops = variables.global_variables_initializer()
with self.session(use_gpu=True) as sess:
- sess.run(my_ops)
+ self.evaluate(my_ops)
# Check the shape of the outputs
t = self.evaluate(outputs)
self.assertAllEqual(t.shape, outputs_shape)
@@ -1062,7 +1062,7 @@
ratio = outputs_2norm / inputs_2norm
my_ops = variables.global_variables_initializer()
with self.cached_session(use_gpu=True) as sess:
- sess.run(my_ops)
+ self.evaluate(my_ops)
# Check the shape of the outputs
t = self.evaluate(outputs)
self.assertAllEqual(t.shape, outputs_shape)
diff --git a/tensorflow/python/kernel_tests/inplace_ops_test.py b/tensorflow/python/kernel_tests/inplace_ops_test.py
index 51d1686..e0c36d3 100644
--- a/tensorflow/python/kernel_tests/inplace_ops_test.py
+++ b/tensorflow/python/kernel_tests/inplace_ops_test.py
@@ -149,7 +149,7 @@
y = inplace_ops.alias_inplace_add(x, [0], [[1, 2, 3]])
with ops.control_dependencies([y]):
z = array_ops.identity(x)
- _, vy, vz = sess.run([x, y, z])
+ _, vy, vz = self.evaluate([x, y, z])
self.assertAllClose(vy, vz)
def testError(self):
diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py
index afa2419..a6b4770 100644
--- a/tensorflow/python/kernel_tests/io_ops_test.py
+++ b/tensorflow/python/kernel_tests/io_ops_test.py
@@ -53,7 +53,7 @@
pass
with self.cached_session() as sess:
w = io_ops.write_file(temp.name, contents)
- sess.run(w)
+ self.evaluate(w)
with open(temp.name, 'rb') as f:
file_contents = f.read()
self.assertEqual(file_contents, contents)
@@ -67,7 +67,7 @@
filepath = os.path.join(subdir, 'subdir2', 'filename')
with self.cached_session() as sess:
w = io_ops.write_file(filepath, contents)
- sess.run(w)
+ self.evaluate(w)
with open(filepath, 'rb') as f:
file_contents = f.read()
self.assertEqual(file_contents, contents)
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
index d5580d0..0986743 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
@@ -557,7 +557,7 @@
self.assertEqual(matrix_tensor.dtype,
linear_operator_circulant._DTYPE_COMPLEX)
matrix_h = linalg.adjoint(matrix_tensor)
- matrix, matrix_h = sess.run([matrix_tensor, matrix_h])
+ matrix, matrix_h = self.evaluate([matrix_tensor, matrix_h])
self.assertAllClose(matrix, matrix_h, atol=0)
def test_assert_non_singular_fails_for_singular_operator(self):
@@ -631,7 +631,7 @@
linear_operator_circulant._DTYPE_COMPLEX)
matrix_h = linalg.adjoint(matrix_tensor)
- matrix, matrix_h = sess.run([matrix_tensor, matrix_h])
+ matrix, matrix_h = self.evaluate([matrix_tensor, matrix_h])
self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), matrix.shape)
self.assertAllClose(matrix, matrix_h)
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
index 91f4097..80889a1 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
@@ -147,12 +147,12 @@
operator_matmul = operator.matmul(x)
mat_matmul = math_ops.matmul(mat, x)
self.assertAllEqual(operator_matmul.get_shape(), mat_matmul.get_shape())
- self.assertAllClose(*sess.run([operator_matmul, mat_matmul]))
+ self.assertAllClose(*self.evaluate([operator_matmul, mat_matmul]))
operator_solve = operator.solve(x)
mat_solve = linalg_ops.matrix_solve(mat, x)
self.assertAllEqual(operator_solve.get_shape(), mat_solve.get_shape())
- self.assertAllClose(*sess.run([operator_solve, mat_solve]))
+ self.assertAllClose(*self.evaluate([operator_solve, mat_solve]))
def test_diag_matmul(self):
operator1 = linalg_lib.LinearOperatorDiag([2., 3.])
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
index 522213e..e9fd91c 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
@@ -170,7 +170,7 @@
expected = x
self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
- self.assertAllClose(*sess.run([operator_matmul, expected]))
+ self.assertAllClose(*self.evaluate([operator_matmul, expected]))
def test_default_batch_shape_broadcasts_with_everything_dynamic(self):
# These cannot be done in the automated (base test class) tests since they
@@ -207,7 +207,7 @@
operator_matmul = operator.matmul(x)
self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
- self.assertAllClose(*sess.run([operator_matmul, expected]))
+ self.assertAllClose(*self.evaluate([operator_matmul, expected]))
def test_broadcast_matmul_dynamic_shapes(self):
# These cannot be done in the automated (base test class) tests since they
@@ -403,13 +403,13 @@
expected = x * 2.2 + zeros
operator_matmul = operator.matmul(x)
self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
- self.assertAllClose(*sess.run([operator_matmul, expected]))
+ self.assertAllClose(*self.evaluate([operator_matmul, expected]))
# Test solve
expected = x / 2.2 + zeros
operator_solve = operator.solve(x)
self.assertAllEqual(operator_solve.get_shape(), expected.get_shape())
- self.assertAllClose(*sess.run([operator_solve, expected]))
+ self.assertAllClose(*self.evaluate([operator_solve, expected]))
def test_broadcast_matmul_and_solve_scalar_scale_multiplier(self):
# These cannot be done in the automated (base test class) tests since they
@@ -429,13 +429,13 @@
expected = x * 2.2
operator_matmul = operator.matmul(x)
self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
- self.assertAllClose(*sess.run([operator_matmul, expected]))
+ self.assertAllClose(*self.evaluate([operator_matmul, expected]))
# Test solve
expected = x / 2.2
operator_solve = operator.solve(x)
self.assertAllEqual(operator_solve.get_shape(), expected.get_shape())
- self.assertAllClose(*sess.run([operator_solve, expected]))
+ self.assertAllClose(*self.evaluate([operator_solve, expected]))
def test_is_x_flags(self):
operator = linalg_lib.LinearOperatorScaledIdentity(
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
index 5ce2616..f127146 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
@@ -119,7 +119,7 @@
with self.cached_session() as sess:
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
- x_bc_, y_bc_ = sess.run([x_bc, y_bc])
+ x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
@@ -138,7 +138,7 @@
with self.cached_session() as sess:
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
- x_bc_, y_bc_ = sess.run([x_bc, y_bc])
+ x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 2bc8ba4..1d9f403 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -29,8 +29,8 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
@@ -806,7 +806,7 @@
l_read2 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
grad = gradients_impl.gradients([l_read1, l_read2], [x])
with self.cached_session() as sess:
- self.assertSequenceEqual(sess.run(grad), [2.])
+ self.assertSequenceEqual(self.evaluate(grad), [2.])
def testSkipEagerBuildElementShape(self):
fn = list_ops._build_element_shape
@@ -834,6 +834,37 @@
self.assertAllEqual(result[:2], [-1, 5])
self.assertIs(result[2], t)
+ def testAddN(self):
+ l1 = list_ops.tensor_list_from_tensor([1.0, 2.0], element_shape=[])
+ l2 = list_ops.tensor_list_from_tensor([3.0, 4.0], element_shape=[])
+ l3 = list_ops.tensor_list_from_tensor([5.0, 6.0], element_shape=[])
+ result = math_ops.add_n((l1, l2, l3))
+ result_t = list_ops.tensor_list_stack(result, element_dtype=dtypes.float32)
+ self.assertAllEqual(self.evaluate(result_t), [9., 12.])
+
+ def testAddNNestedList(self):
+ l1 = list_ops.tensor_list_from_tensor([1.0, 2.0], element_shape=[])
+ l2 = list_ops.tensor_list_from_tensor([3.0, 4.0], element_shape=[])
+ l3 = list_ops.tensor_list_from_tensor([5.0, 6.0], element_shape=[])
+ l4 = list_ops.tensor_list_from_tensor([7.0, 8.0], element_shape=[])
+ a = list_ops.empty_tensor_list(
+ element_dtype=dtypes.variant, element_shape=[])
+ a = list_ops.tensor_list_push_back(a, l1)
+ a = list_ops.tensor_list_push_back(a, l2)
+ b = list_ops.empty_tensor_list(
+ element_dtype=dtypes.variant, element_shape=[])
+ b = list_ops.tensor_list_push_back(b, l3)
+ b = list_ops.tensor_list_push_back(b, l4)
+ result = math_ops.add_n((a, b))
+ result_0 = list_ops.tensor_list_stack(
+ list_ops.tensor_list_get_item(result, 0, element_dtype=dtypes.variant),
+ element_dtype=dtypes.float32)
+ result_1 = list_ops.tensor_list_stack(
+ list_ops.tensor_list_get_item(result, 1, element_dtype=dtypes.variant),
+ element_dtype=dtypes.float32)
+ self.assertAllEqual(self.evaluate(result_0), [6., 8.])
+ self.assertAllEqual(self.evaluate(result_1), [10., 12.])
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/listdiff_op_test.py b/tensorflow/python/kernel_tests/listdiff_op_test.py
index baeb40d..2865710 100644
--- a/tensorflow/python/kernel_tests/listdiff_op_test.py
+++ b/tensorflow/python/kernel_tests/listdiff_op_test.py
@@ -47,7 +47,7 @@
y_tensor = ops.convert_to_tensor(y, dtype=dtype)
out_tensor, idx_tensor = diff_func(x_tensor, y_tensor,
index_dtype=index_dtype)
- tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+ tf_out, tf_idx = self.evaluate([out_tensor, idx_tensor])
self.assertAllEqual(tf_out, out)
self.assertAllEqual(tf_idx, idx)
self.assertEqual(1, out_tensor.get_shape().ndims)
diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py
index ab4c9c7..79961d8 100644
--- a/tensorflow/python/kernel_tests/lookup_ops_test.py
+++ b/tensorflow/python/kernel_tests/lookup_ops_test.py
@@ -137,7 +137,7 @@
output2 = table2.lookup(input_string)
output3 = table3.lookup(input_string)
- out1, out2, out3 = sess.run([output1, output2, output3])
+ out1, out2, out3 = self.evaluate([output1, output2, output3])
self.assertAllEqual([0, 1, -1], out1)
self.assertAllEqual([0, 1, -1], out2)
self.assertAllEqual([0, 1, -1], out3)
@@ -995,7 +995,7 @@
output2 = table2.lookup(input_string)
output3 = table3.lookup(input_string)
- out1, out2, out3 = sess.run([output1, output2, output3])
+ out1, out2, out3 = self.evaluate([output1, output2, output3])
self.assertAllEqual([0, 1, -1], out1)
self.assertAllEqual([0, 1, -1], out2)
self.assertAllEqual([0, 1, -1], out3)
@@ -1313,7 +1313,7 @@
out1 = table1.lookup(input_string)
out2 = table2.lookup(input_string)
- out1, out2 = sess.run([out1, out2])
+ out1, out2 = self.evaluate([out1, out2])
self.assertAllEqual([5, 0, 1, 2, 5], out1)
self.assertAllEqual([5, 0, 1, 2, 3], out2)
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
@@ -1396,7 +1396,7 @@
out1 = table1.lookup(input_string_1)
out2 = table2.lookup(input_string_2)
- out1, out2 = sess.run([out1, out2])
+ out1, out2 = self.evaluate([out1, out2])
self.assertAllEqual([0, 1, 2, -1], out1)
self.assertAllEqual([-2, 1, -2], out2)
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
diff --git a/tensorflow/python/kernel_tests/map_stage_op_test.py b/tensorflow/python/kernel_tests/map_stage_op_test.py
index 4b5bd40..d503f3d 100644
--- a/tensorflow/python/kernel_tests/map_stage_op_test.py
+++ b/tensorflow/python/kernel_tests/map_stage_op_test.py
@@ -148,7 +148,7 @@
for i in range(n):
self.assertTrue(sess.run(peek, feed_dict={gi: i})[0] == i)
- self.assertTrue(self.evaluate(size) == 10)
+ self.assertTrue(sess.run(size) == 10)
def testSizeAndClear(self):
with ops.Graph().as_default() as G:
@@ -170,11 +170,11 @@
with self.session(use_gpu=True, graph=G) as sess:
sess.run(stage, feed_dict={x: -1, pi: 3})
- self.assertEqual(self.evaluate(size), 1)
+ self.assertEqual(sess.run(size), 1)
sess.run(stage, feed_dict={x: -1, pi: 1})
- self.assertEqual(self.evaluate(size), 2)
+ self.assertEqual(sess.run(size), 2)
sess.run(clear)
- self.assertEqual(self.evaluate(size), 0)
+ self.assertEqual(sess.run(size), 0)
def testCapacity(self):
capacity = 3
@@ -231,13 +231,13 @@
capacity))
# Should have capacity elements in the staging area
- self.assertTrue(self.evaluate(size) == capacity)
+ self.assertTrue(sess.run(size) == capacity)
# Clear the staging area completely
for i in range(n):
sess.run(get)
- self.assertTrue(self.evaluate(size) == 0)
+ self.assertTrue(sess.run(size) == 0)
def testMemoryLimit(self):
memory_limit = 512 * 1024 # 512K
@@ -295,13 +295,13 @@
capacity))
# Should have capacity elements in the staging area
- self.assertTrue(self.evaluate(size) == capacity)
+ self.assertTrue(sess.run(size) == capacity)
# Clear the staging area completely
for i in range(n):
sess.run(get)
- self.assertTrue(self.evaluate(size) == 0)
+ self.assertTrue(sess.run(size) == 0)
def testOrdering(self):
import six
@@ -332,14 +332,14 @@
for i in keys:
sess.run(stage, feed_dict={pi: i, x: i})
- self.assertTrue(self.evaluate(size) == n)
+ self.assertTrue(sess.run(size) == n)
# Check that key, values come out in ascending order
for i, k in enumerate(reversed(keys)):
- get_key, values = self.evaluate(get)
+ get_key, values = sess.run(get)
self.assertTrue(i == k == get_key == values)
- self.assertTrue(self.evaluate(size) == 0)
+ self.assertTrue(sess.run(size) == 0)
def testPartialDictInsert(self):
with ops.Graph().as_default() as G:
diff --git a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
index 7fe6cd4..83f4216 100644
--- a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
@@ -25,6 +25,7 @@
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import random_ops
@@ -50,7 +51,7 @@
def _verifyExponential(self, x, np_type):
inp = x.astype(np_type)
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
tf_ans = linalg_impl.matrix_exponential(inp)
if x.size == 0:
np_ans = np.empty(x.shape, dtype=np_type)
@@ -150,7 +151,7 @@
matrix2 = random_ops.random_normal([5, 5], seed=42)
expm1 = linalg_impl.matrix_exponential(matrix1)
expm2 = linalg_impl.matrix_exponential(matrix2)
- expm = sess.run([expm1, expm2])
+ expm = self.evaluate([expm1, expm2])
self.assertAllEqual(expm[0], expm[1])
diff --git a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
index 102502a..b0bce6a 100644
--- a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
@@ -25,6 +25,7 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
@@ -39,7 +40,7 @@
def _verifyLogarithm(self, x, np_type):
inp = x.astype(np_type)
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
# Verify that expm(logm(A)) == A.
tf_ans = linalg_impl.matrix_exponential(
gen_linalg_ops.matrix_logarithm(inp))
@@ -128,7 +129,7 @@
random_ops.random_normal([5, 5], seed=42), dtypes.complex64)
logm1 = gen_linalg_ops.matrix_logarithm(matrix1)
logm2 = gen_linalg_ops.matrix_logarithm(matrix2)
- logm = sess.run([logm1, logm2])
+ logm = self.evaluate([logm1, logm2])
self.assertAllEqual(logm[0], logm[1])
diff --git a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py
index 1f2144b..1e2109b 100644
--- a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py
@@ -21,6 +21,7 @@
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@@ -31,7 +32,7 @@
def _verifySquareRoot(self, matrix, np_type):
matrix = matrix.astype(np_type)
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
# Verify that matmul(sqrtm(A), sqrtm(A)) = A
sqrt = gen_linalg_ops.matrix_square_root(matrix)
square = math_ops.matmul(sqrt, sqrt)
@@ -96,17 +97,18 @@
gen_linalg_ops.matrix_square_root(tensor)
def testNotSquare(self):
- with self.test_session():
- with self.assertRaises(ValueError):
- tensor = constant_op.constant([[1., 0., -1.], [-1., 1., 0.]])
- gen_linalg_ops.matrix_square_root(tensor).eval()
+ with self.assertRaises(ValueError):
+ tensor = constant_op.constant([[1., 0., -1.], [-1., 1., 0.]])
+ self.evaluate(gen_linalg_ops.matrix_square_root(tensor))
def testConcurrentExecutesWithoutError(self):
- with self.test_session(use_gpu=True) as sess:
+ with test_util.use_gpu():
matrix1 = random_ops.random_normal([5, 5], seed=42)
matrix2 = random_ops.random_normal([5, 5], seed=42)
- sqrt1 = gen_linalg_ops.matrix_square_root(matrix1)
- sqrt2 = gen_linalg_ops.matrix_square_root(matrix2)
+ square1 = math_ops.matmul(matrix1, matrix1)
+ square2 = math_ops.matmul(matrix2, matrix2)
+ sqrt1 = gen_linalg_ops.matrix_square_root(square1)
+ sqrt2 = gen_linalg_ops.matrix_square_root(square2)
all_ops = [sqrt1, sqrt2]
sqrt = self.evaluate(all_ops)
self.assertAllEqual(sqrt[0], sqrt[1])
diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py
index b683271..eb5f995 100644
--- a/tensorflow/python/kernel_tests/metrics_test.py
+++ b/tensorflow/python/kernel_tests/metrics_test.py
@@ -1643,11 +1643,11 @@
self.evaluate(variables.local_variables_initializer())
# Run several updates, then verify idempotency.
- sess.run([prec_op, rec_op])
+ self.evaluate([prec_op, rec_op])
initial_prec = prec.eval()
initial_rec = rec.eval()
for _ in range(10):
- sess.run([prec_op, rec_op])
+ self.evaluate([prec_op, rec_op])
self.assertAllClose(initial_prec, prec.eval())
self.assertAllClose(initial_rec, rec.eval())
@@ -1665,7 +1665,7 @@
thresholds)
self.evaluate(variables.local_variables_initializer())
- sess.run([prec_op, rec_op])
+ self.evaluate([prec_op, rec_op])
self.assertEqual(1, prec.eval())
self.assertEqual(1, rec.eval())
@@ -1685,7 +1685,7 @@
thresholds)
self.evaluate(variables.local_variables_initializer())
- sess.run([prec_op, rec_op])
+ self.evaluate([prec_op, rec_op])
self.assertAlmostEqual(0.5, prec.eval())
self.assertAlmostEqual(0.5, rec.eval())
@@ -1703,7 +1703,7 @@
thresholds)
self.evaluate(variables.local_variables_initializer())
- sess.run([prec_op, rec_op])
+ self.evaluate([prec_op, rec_op])
self.assertAlmostEqual(0, prec.eval())
self.assertAlmostEqual(0, rec.eval())
@@ -1731,7 +1731,7 @@
rec_high = array_ops.reshape(rec_high, shape=())
self.evaluate(variables.local_variables_initializer())
- sess.run([prec_op, rec_op])
+ self.evaluate([prec_op, rec_op])
self.assertAlmostEqual(1.0, prec_low.eval(), places=5)
self.assertAlmostEqual(0.0, prec_high.eval(), places=5)
@@ -1761,7 +1761,7 @@
rec_high = array_ops.reshape(rec_high, shape=())
self.evaluate(variables.local_variables_initializer())
- sess.run([prec_op, rec_op])
+ self.evaluate([prec_op, rec_op])
self.assertAlmostEqual(1.0, prec_low.eval(), places=5)
self.assertAlmostEqual(0.0, prec_high.eval(), places=5)
@@ -1785,7 +1785,7 @@
value=rec, num_or_size_splits=2, axis=0)
self.evaluate(variables.local_variables_initializer())
- sess.run([prec_op, rec_op])
+ self.evaluate([prec_op, rec_op])
self.assertAlmostEqual(0.75, prec_low.eval())
self.assertAlmostEqual(0.0, prec_high.eval())
@@ -1803,7 +1803,7 @@
thresholds)
self.evaluate(variables.local_variables_initializer())
- sess.run([prec_op, rec_op])
+ self.evaluate([prec_op, rec_op])
self.assertAlmostEqual(0, prec.eval(), 6)
self.assertAlmostEqual(0, rec.eval(), 6)
@@ -1872,7 +1872,7 @@
self.evaluate(variables.local_variables_initializer())
for _ in range(int(num_samples / batch_size)):
- sess.run([prec_op, rec_op])
+ self.evaluate([prec_op, rec_op])
# Since this is only approximate, we can't expect a 6 digits match.
# Although with higher number of samples/thresholds we should see the
# accuracy improving
@@ -3056,10 +3056,10 @@
labels1, predictions1, name='msd1')
self.evaluate(variables.local_variables_initializer())
- sess.run([update_op0, update_op1])
- sess.run([update_op0, update_op1])
+ self.evaluate([update_op0, update_op1])
+ self.evaluate([update_op0, update_op1])
- mse0, mse1 = sess.run([mse0, mse1])
+ mse0, mse1 = self.evaluate([mse0, mse1])
self.assertAlmostEqual(208.0 / 6, mse0, 5)
self.assertAlmostEqual(79.0 / 6, mse1, 5)
@@ -3083,8 +3083,8 @@
mse, ms_update_op = metrics.mean_squared_error(labels, predictions)
self.evaluate(variables.local_variables_initializer())
- sess.run([ma_update_op, ms_update_op])
- sess.run([ma_update_op, ms_update_op])
+ self.evaluate([ma_update_op, ms_update_op])
+ self.evaluate([ma_update_op, ms_update_op])
self.assertAlmostEqual(32.0 / 6, mae.eval(), 5)
self.assertAlmostEqual(208.0 / 6, mse.eval(), 5)
@@ -3362,9 +3362,9 @@
pcnt2, update_op2 = metrics.percentage_below(values, 1, name='low')
self.evaluate(variables.local_variables_initializer())
- sess.run([update_op0, update_op1, update_op2])
+ self.evaluate([update_op0, update_op1, update_op2])
- pcnt0, pcnt1, pcnt2 = sess.run([pcnt0, pcnt1, pcnt2])
+ pcnt0, pcnt1, pcnt2 = self.evaluate([pcnt0, pcnt1, pcnt2])
self.assertAlmostEqual(1.0, pcnt0, 5)
self.assertAlmostEqual(0.75, pcnt1, 5)
self.assertAlmostEqual(0.0, pcnt2, 5)
@@ -3385,9 +3385,9 @@
self.evaluate(variables.local_variables_initializer())
self.assertListEqual([1.0, 0.5, 0.0],
- sess.run([update_op0, update_op1, update_op2]))
+ self.evaluate([update_op0, update_op1, update_op2]))
- pcnt0, pcnt1, pcnt2 = sess.run([pcnt0, pcnt1, pcnt2])
+ pcnt0, pcnt1, pcnt2 = self.evaluate([pcnt0, pcnt1, pcnt2])
self.assertAlmostEqual(1.0, pcnt0, 5)
self.assertAlmostEqual(0.5, pcnt1, 5)
self.assertAlmostEqual(0.0, pcnt2, 5)
diff --git a/tensorflow/python/kernel_tests/numerics_test.py b/tensorflow/python/kernel_tests/numerics_test.py
index d25d973..e3210dc 100644
--- a/tensorflow/python/kernel_tests/numerics_test.py
+++ b/tensorflow/python/kernel_tests/numerics_test.py
@@ -23,6 +23,7 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -35,7 +36,7 @@
def testVerifyTensorAllFiniteSucceeds(self):
x_shape = [5, 4]
x = np.random.random_sample(x_shape).astype(np.float32)
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
t = constant_op.constant(x, shape=x_shape, dtype=dtypes.float32)
t_verified = numerics.verify_tensor_all_finite(t,
"Input is not a number.")
@@ -48,7 +49,7 @@
# Test NaN.
x[0] = np.nan
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
with self.assertRaisesOpError(my_msg):
t = constant_op.constant(x, shape=x_shape, dtype=dtypes.float32)
t_verified = numerics.verify_tensor_all_finite(t, my_msg)
@@ -56,7 +57,7 @@
# Test Inf.
x[0] = np.inf
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
with self.assertRaisesOpError(my_msg):
t = constant_op.constant(x, shape=x_shape, dtype=dtypes.float32)
t_verified = numerics.verify_tensor_all_finite(t, my_msg)
diff --git a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
index 3696298..b481836 100644
--- a/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
+++ b/tensorflow/python/kernel_tests/padding_fifo_queue_test.py
@@ -158,7 +158,7 @@
results = []
def dequeue():
- results.append(sess.run(dequeued_t))
+ results.append(self.evaluate(dequeued_t))
threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
for thread in threads:
@@ -199,7 +199,7 @@
def dequeue():
for _ in xrange(len(elems)):
- results.append(sess.run(dequeued_t))
+ results.append(self.evaluate(dequeued_t))
enqueue_thread = self.checkedThread(target=enqueue)
dequeue_thread = self.checkedThread(target=dequeue)
@@ -656,7 +656,7 @@
dequeued_elems = []
def dequeue():
- dequeued_elems.extend(sess.run(dequeued_t))
+ dequeued_elems.extend(self.evaluate(dequeued_t))
threads = [self.checkedThread(target=dequeue) for _ in range(10)]
for thread in threads:
@@ -680,7 +680,7 @@
dequeued_elems = []
def dequeue():
- dequeued_elems.extend(sess.run(dequeued_t))
+ dequeued_elems.extend(self.evaluate(dequeued_t))
threads = [self.checkedThread(target=dequeue) for _ in range(10)]
for thread in threads:
@@ -808,7 +808,7 @@
self.evaluate(enqueue_op)
def dequeue():
- dequeued_elems.extend(sess.run(dequeued_t).tolist())
+ dequeued_elems.extend(self.evaluate(dequeued_t).tolist())
enqueue_thread = self.checkedThread(target=enqueue)
dequeue_thread = self.checkedThread(target=dequeue)
@@ -835,7 +835,7 @@
self.evaluate(enqueue_op)
def dequeue():
- dequeued_elems.extend(sess.run(dequeued_t).tolist())
+ dequeued_elems.extend(self.evaluate(dequeued_t).tolist())
enqueue_thread = self.checkedThread(target=enqueue)
dequeue_thread = self.checkedThread(target=dequeue)
@@ -905,7 +905,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -947,7 +947,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -972,7 +972,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -997,7 +997,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -1022,7 +1022,7 @@
def dequeue():
self.assertAllEqual(elems[0:3], self.evaluate(dequeued_t))
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
self.assertEqual(elems[3], self.evaluate(cleanup_dequeue_t))
def close():
@@ -1059,7 +1059,7 @@
def dequeue():
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run([dequeued_a_t, dequeued_b_t])
+ self.evaluate([dequeued_a_t, dequeued_b_t])
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -1072,7 +1072,7 @@
# Test that the elements in the partially-dequeued batch are
# restored in the correct order.
for elem_a, elem_b in zip(elems_a, elems_b):
- val_a, val_b = sess.run([cleanup_dequeue_a_t, cleanup_dequeue_b_t])
+ val_a, val_b = self.evaluate([cleanup_dequeue_a_t, cleanup_dequeue_b_t])
self.assertEqual(elem_a, val_a)
self.assertEqual(elem_b, val_b)
self.assertEqual(0, q.size().eval())
@@ -1087,7 +1087,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -1107,7 +1107,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -1434,7 +1434,7 @@
def blocking_enqueue():
enq_done.append(False)
# This will fill the queue and then block until enough dequeues happen.
- sess.run(enq)
+ self.evaluate(enq)
enq_done.append(True)
thread = self.checkedThread(target=blocking_enqueue)
@@ -1477,7 +1477,7 @@
def blocking_dequeue():
# Will only complete after 4 enqueues complete.
- results.extend(sess.run(deq))
+ results.extend(self.evaluate(deq))
thread = self.checkedThread(target=blocking_dequeue)
thread.start()
@@ -1486,7 +1486,7 @@
# TODO(mrry): Figure out how to do this without sleeping.
time.sleep(0.1)
self.assertEqual(len(results), 0)
- sess.run(enq)
+ self.evaluate(enq)
# Enough enqueued to unblock the dequeue
thread.join()
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index d87adbf..1f67710 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -1700,7 +1700,7 @@
json_tensor = constant_op.constant(["{]"])
binary_tensor = parsing_ops.decode_json_example(json_tensor)
with self.assertRaisesOpError("Error while parsing JSON"):
- sess.run(binary_tensor)
+ self.evaluate(binary_tensor)
class ParseTensorOpTest(test.TestCase):
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 61628c4..8122271 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -826,7 +826,7 @@
strides=[1, 1, 1, 1],
Targmax=dtypes.int64,
padding="VALID")
- out, argmax = sess.run([out_op, argmax_op])
+ out, argmax = self.evaluate([out_op, argmax_op])
self.assertShapeEqual(out, out_op)
self.assertShapeEqual(argmax, argmax_op)
self.assertAllClose(out.ravel(), [1.0, 1.0, 1.0, 1.0])
diff --git a/tensorflow/python/kernel_tests/priority_queue_test.py b/tensorflow/python/kernel_tests/priority_queue_test.py
index a510fcc..9be682e 100644
--- a/tensorflow/python/kernel_tests/priority_queue_test.py
+++ b/tensorflow/python/kernel_tests/priority_queue_test.py
@@ -215,7 +215,7 @@
# We can't guarantee full sorting because we can't guarantee
# that the dequeued.extend() call runs immediately after the
- # sess.run() call. Here we're just happy everything came out.
+ # self.evaluate() call. Here we're just happy everything came out.
self.assertAllEqual(set(dequeued), set(all_enqueued_values))
def testRoundTripInsertManyMultiThreadedReadOnceSorts(self):
diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py
index 114481e..305b5aa 100644
--- a/tensorflow/python/kernel_tests/qr_op_test.py
+++ b/tensorflow/python/kernel_tests/qr_op_test.py
@@ -129,7 +129,7 @@
q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices_)
if use_static_shape_:
- q_tf_val, r_tf_val = sess.run([q_tf, r_tf])
+ q_tf_val, r_tf_val = self.evaluate([q_tf, r_tf])
else:
q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})
diff --git a/tensorflow/python/kernel_tests/random/multinomial_op_test.py b/tensorflow/python/kernel_tests/random/multinomial_op_test.py
index 8d2718c..031a1c2 100644
--- a/tensorflow/python/kernel_tests/random/multinomial_op_test.py
+++ b/tensorflow/python/kernel_tests/random/multinomial_op_test.py
@@ -67,7 +67,7 @@
self.assertAllEqual([[1] * num_samples, [2] * num_samples], samples)
def testOneOpMultipleStepsIndependent(self):
- with self.test_session(use_gpu=True) as sess:
+ with test_util.use_gpu():
sample_op1, _ = self._make_ops(10)
# Consecutive runs shouldn't yield identical output.
sample1a = self.evaluate(sample_op1)
@@ -81,26 +81,26 @@
self.assertFalse(np.equal(sample1.numpy(), sample2.numpy()).all())
def testTwoOpsIndependent(self):
- with self.test_session(use_gpu=True) as sess:
+ with test_util.use_gpu():
sample_op1, sample_op2 = self._make_ops(32)
- sample1, sample2 = sess.run([sample_op1, sample_op2])
+ sample1, sample2 = self.evaluate([sample_op1, sample_op2])
# We expect sample1 and sample2 to be independent.
# 1 in 2^32 chance of this assertion failing.
self.assertFalse(np.equal(sample1, sample2).all())
def testTwoOpsSameSeedDrawSameSequences(self):
- with self.test_session(use_gpu=True) as sess:
+ with test_util.use_gpu():
sample_op1, sample_op2 = self._make_ops(1000, seed=1)
- sample1, sample2 = sess.run([sample_op1, sample_op2])
+ sample1, sample2 = self.evaluate([sample_op1, sample_op2])
self.assertAllEqual(sample1, sample2)
def testLargeLogits(self):
for neg in [True, False]:
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
logits = np.array([[1000.] * 5])
if neg:
logits *= -1
- samples = random_ops.multinomial(logits, 10).eval()
+ samples = self.evaluate(random_ops.multinomial(logits, 10))
# Sampled classes should be in-range.
self.assertTrue((samples >= 0).all())
self.assertTrue((samples < 5).all())
@@ -157,7 +157,7 @@
Returns:
Frequencies from sampled classes; shape [batch_size, num_classes].
"""
- with self.test_session(use_gpu=True) as sess:
+ with test_util.use_gpu():
random_seed.set_random_seed(1618)
op = sampler(constant_op.constant(logits), num_samples)
d = self.evaluate(op)
@@ -186,25 +186,26 @@
def testEmpty(self):
classes = 5
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
for batch in 0, 3:
for samples in 0, 7:
- x = random_ops.multinomial(
- array_ops.zeros([batch, classes]), samples).eval()
+ x = self.evaluate(
+ random_ops.multinomial(
+ array_ops.zeros([batch, classes]), samples))
self.assertEqual(x.shape, (batch, samples))
def testEmptyClasses(self):
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
x = random_ops.multinomial(array_ops.zeros([5, 0]), 7)
with self.assertRaisesOpError("num_classes should be positive"):
self.evaluate(x)
def testNegativeMinLogits(self):
random_seed.set_random_seed(78844)
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
logits = constant_op.constant([[np.finfo(np.float32).min] * 1023 + [0]])
num_samples = 1000
- samples = random_ops.multinomial(logits, num_samples).eval()
+ samples = self.evaluate(random_ops.multinomial(logits, num_samples))
self.assertAllEqual([[1023] * num_samples], samples)
@@ -225,10 +226,8 @@
native_op = control_flow_ops.group(native_sampler(logits, num_samples))
composed_op = control_flow_ops.group(composed_sampler(logits, num_samples))
- native_dt = timeit.timeit(
- lambda: sess.run(native_op), number=num_iters)
- composed_dt = timeit.timeit(
- lambda: sess.run(composed_op), number=num_iters)
+ native_dt = timeit.timeit(lambda: sess.run(native_op), number=num_iters)
+ composed_dt = timeit.timeit(lambda: sess.run(composed_op), number=num_iters)
return native_dt, composed_dt
diff --git a/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py b/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py
index 5601b98..ed4f543 100644
--- a/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py
+++ b/tensorflow/python/kernel_tests/random/random_shuffle_queue_test.py
@@ -86,7 +86,7 @@
for _ in range(2):
a, b = self.evaluate(dequeue_t)
results.append((a, b))
- a, b = sess.run(q.dequeue_many(3))
+ a, b = self.evaluate(q.dequeue_many(3))
for i in range(3):
results.append((a[i], b[i]))
self.assertItemsEqual([(1, [5]), (2, [6]), (3, [7]), (4, [8]), (9, [10])],
@@ -133,7 +133,7 @@
results = []
def dequeue():
- results.append(sess.run(dequeued_t))
+ results.append(self.evaluate(dequeued_t))
threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
for thread in threads:
@@ -173,7 +173,7 @@
def dequeue():
for _ in xrange(len(elems)):
- results.append(sess.run(dequeued_t))
+ results.append(self.evaluate(dequeued_t))
enqueue_thread = self.checkedThread(target=enqueue)
dequeue_thread = self.checkedThread(target=dequeue)
@@ -466,7 +466,7 @@
dequeued_elems = []
def dequeue():
- dequeued_elems.extend(sess.run(dequeued_t))
+ dequeued_elems.extend(self.evaluate(dequeued_t))
threads = [self.checkedThread(target=dequeue) for _ in range(10)]
for thread in threads:
@@ -489,7 +489,7 @@
dequeued_elems = []
def dequeue():
- dequeued_elems.extend(sess.run(dequeued_t))
+ dequeued_elems.extend(self.evaluate(dequeued_t))
threads = [self.checkedThread(target=dequeue) for _ in range(10)]
for thread in threads:
@@ -542,7 +542,7 @@
self.evaluate(enqueue_op)
def dequeue():
- dequeued_elems.extend(sess.run(dequeued_t).tolist())
+ dequeued_elems.extend(self.evaluate(dequeued_t).tolist())
enqueue_thread = self.checkedThread(target=enqueue)
dequeue_thread = self.checkedThread(target=dequeue)
@@ -569,7 +569,7 @@
self.evaluate(enqueue_op)
def dequeue():
- dequeued_elems.extend(sess.run(dequeued_t).tolist())
+ dequeued_elems.extend(self.evaluate(dequeued_t).tolist())
enqueue_thread = self.checkedThread(target=enqueue)
dequeue_thread = self.checkedThread(target=dequeue)
@@ -665,18 +665,18 @@
results = []
# Manually dequeue until we hit min_size.
- results.append(sess.run(dequeued_t))
- results.append(sess.run(dequeued_t))
+ results.append(self.evaluate(dequeued_t))
+ results.append(self.evaluate(dequeued_t))
def blocking_dequeue():
- results.append(sess.run(dequeued_t))
- results.append(sess.run(dequeued_t))
+ results.append(self.evaluate(dequeued_t))
+ results.append(self.evaluate(dequeued_t))
self.assertItemsEqual(elems, results)
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=blocking_dequeue)
dequeue_thread.start()
@@ -701,7 +701,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
finished.append(True)
dequeue_thread = self.checkedThread(target=dequeue)
@@ -732,7 +732,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
progress.append(2)
self.assertEqual(len(progress), 0)
@@ -763,9 +763,9 @@
results = []
def dequeue():
- results.extend(sess.run(dequeued_t))
+ results.extend(self.evaluate(dequeued_t))
self.assertEquals(3, len(results))
- results.extend(sess.run(dequeued_t))
+ results.extend(self.evaluate(dequeued_t))
self.assertEquals(4, len(results))
dequeue_thread = self.checkedThread(target=dequeue)
@@ -794,11 +794,11 @@
results = []
def dequeue():
- results.extend(sess.run(dequeued_t))
+ results.extend(self.evaluate(dequeued_t))
self.assertEquals(3, len(results))
# min_after_dequeue is 2, we ask for 3 elements, and we end up only
# getting the remaining 1.
- results.extend(sess.run(dequeued_t))
+ results.extend(self.evaluate(dequeued_t))
self.assertEquals(4, len(results))
dequeue_thread = self.checkedThread(target=dequeue)
@@ -824,16 +824,16 @@
results = []
def dequeue():
- results.extend(sess.run(dequeued_t))
+ results.extend(self.evaluate(dequeued_t))
self.assertEqual(len(results), 3)
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
# While the last dequeue failed, we want to insure that it returns
# any elements that it potentially reserved to dequeue. Thus the
# next cleanup should return a single element.
- results.extend(sess.run(cleanup_dequeue_t))
+ results.extend(self.evaluate(cleanup_dequeue_t))
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -854,7 +854,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -874,7 +874,7 @@
# Expect the operation to fail due to the queue being closed.
with self.assertRaisesRegexp(errors_impl.OutOfRangeError,
"is closed and has insufficient"):
- sess.run(dequeued_t)
+ self.evaluate(dequeued_t)
dequeue_thread = self.checkedThread(target=dequeue)
dequeue_thread.start()
@@ -1383,7 +1383,7 @@
def blocking_enqueue():
enq_done.append(False)
# This will fill the queue and then block until enough dequeues happen.
- sess.run(enq)
+ self.evaluate(enq)
enq_done.append(True)
thread = self.checkedThread(target=blocking_enqueue)
@@ -1426,7 +1426,7 @@
def blocking_dequeue():
# Will only complete after 4 enqueues complete.
- results.extend(sess.run(deq))
+ results.extend(self.evaluate(deq))
thread = self.checkedThread(target=blocking_dequeue)
thread.start()
@@ -1435,7 +1435,7 @@
# TODO(mrry): Figure out how to do this without sleeping.
time.sleep(0.1)
self.assertEqual(len(results), 0)
- sess.run(enq)
+ self.evaluate(enq)
# Enough enqueued to unblock the dequeue
thread.join()
diff --git a/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py b/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py
index 13f97a9..071d6c2 100644
--- a/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random/stateless_random_ops_test.py
@@ -24,6 +24,7 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import stateless_random_ops as stateless
@@ -58,11 +59,11 @@
preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
preseed = preseed[::2] | preseed[1::2] << 32
random_seed.set_random_seed(seed[0])
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
for stateless_op, stateful_op in cases:
stateful = stateful_op(seed=seed[1])
pure = stateless_op(seed=preseed)
- self.assertAllEqual(stateful.eval(), self.evaluate(pure))
+ self.assertAllEqual(self.evaluate(stateful), self.evaluate(pure))
def _test_determinism(self, cases):
# Stateless values should be equal iff the seeds are equal (roughly)
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index 4d9b26f..a4a18c5 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -140,147 +140,143 @@
class IdentityReaderTest(test.TestCase):
- def _ExpectRead(self, sess, key, value, expected):
- k, v = sess.run([key, value])
+ def _ExpectRead(self, key, value, expected):
+ k, v = self.evaluate([key, value])
self.assertAllEqual(expected, k)
self.assertAllEqual(expected, v)
def testOneEpoch(self):
- with self.cached_session() as sess:
- reader = io_ops.IdentityReader("test_reader")
- work_completed = reader.num_work_units_completed()
- produced = reader.num_records_produced()
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- queued_length = queue.size()
- key, value = reader.read(queue)
+ reader = io_ops.IdentityReader("test_reader")
+ work_completed = reader.num_work_units_completed()
+ produced = reader.num_records_produced()
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ queued_length = queue.size()
+ key, value = reader.read(queue)
- self.assertAllEqual(0, self.evaluate(work_completed))
- self.assertAllEqual(0, self.evaluate(produced))
- self.assertAllEqual(0, self.evaluate(queued_length))
+ self.assertAllEqual(0, self.evaluate(work_completed))
+ self.assertAllEqual(0, self.evaluate(produced))
+ self.assertAllEqual(0, self.evaluate(queued_length))
- queue.enqueue_many([["A", "B", "C"]]).run()
- queue.close().run()
- self.assertAllEqual(3, self.evaluate(queued_length))
+ self.evaluate(queue.enqueue_many([["A", "B", "C"]]))
+ self.evaluate(queue.close())
+ self.assertAllEqual(3, self.evaluate(queued_length))
- self._ExpectRead(sess, key, value, b"A")
- self.assertAllEqual(1, self.evaluate(produced))
+ self._ExpectRead(key, value, b"A")
+ self.assertAllEqual(1, self.evaluate(produced))
- self._ExpectRead(sess, key, value, b"B")
+ self._ExpectRead(key, value, b"B")
- self._ExpectRead(sess, key, value, b"C")
- self.assertAllEqual(3, self.evaluate(produced))
- self.assertAllEqual(0, self.evaluate(queued_length))
+ self._ExpectRead(key, value, b"C")
+ self.assertAllEqual(3, self.evaluate(produced))
+ self.assertAllEqual(0, self.evaluate(queued_length))
- with self.assertRaisesOpError("is closed and has insufficient elements "
- "\\(requested 1, current size 0\\)"):
- sess.run([key, value])
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ self.evaluate([key, value])
- self.assertAllEqual(3, self.evaluate(work_completed))
- self.assertAllEqual(3, self.evaluate(produced))
- self.assertAllEqual(0, self.evaluate(queued_length))
+ self.assertAllEqual(3, self.evaluate(work_completed))
+ self.assertAllEqual(3, self.evaluate(produced))
+ self.assertAllEqual(0, self.evaluate(queued_length))
def testMultipleEpochs(self):
- with self.cached_session() as sess:
- reader = io_ops.IdentityReader("test_reader")
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- enqueue = queue.enqueue_many([["DD", "EE"]])
- key, value = reader.read(queue)
+ reader = io_ops.IdentityReader("test_reader")
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ enqueue = queue.enqueue_many([["DD", "EE"]])
+ key, value = reader.read(queue)
- enqueue.run()
- self._ExpectRead(sess, key, value, b"DD")
- self._ExpectRead(sess, key, value, b"EE")
- enqueue.run()
- self._ExpectRead(sess, key, value, b"DD")
- self._ExpectRead(sess, key, value, b"EE")
- enqueue.run()
- self._ExpectRead(sess, key, value, b"DD")
- self._ExpectRead(sess, key, value, b"EE")
- queue.close().run()
- with self.assertRaisesOpError("is closed and has insufficient elements "
- "\\(requested 1, current size 0\\)"):
- sess.run([key, value])
+ self.evaluate(enqueue)
+ self._ExpectRead(key, value, b"DD")
+ self._ExpectRead(key, value, b"EE")
+ self.evaluate(enqueue)
+ self._ExpectRead(key, value, b"DD")
+ self._ExpectRead(key, value, b"EE")
+ self.evaluate(enqueue)
+ self._ExpectRead(key, value, b"DD")
+ self._ExpectRead(key, value, b"EE")
+ self.evaluate(queue.close())
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ self.evaluate([key, value])
def testSerializeRestore(self):
- with self.cached_session() as sess:
- reader = io_ops.IdentityReader("test_reader")
- produced = reader.num_records_produced()
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- queue.enqueue_many([["X", "Y", "Z"]]).run()
- key, value = reader.read(queue)
+ reader = io_ops.IdentityReader("test_reader")
+ produced = reader.num_records_produced()
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ self.evaluate(queue.enqueue_many([["X", "Y", "Z"]]))
+ key, value = reader.read(queue)
- self._ExpectRead(sess, key, value, b"X")
- self.assertAllEqual(1, self.evaluate(produced))
- state = reader.serialize_state().eval()
+ self._ExpectRead(key, value, b"X")
+ self.assertAllEqual(1, self.evaluate(produced))
+ state = self.evaluate(reader.serialize_state())
- self._ExpectRead(sess, key, value, b"Y")
- self._ExpectRead(sess, key, value, b"Z")
- self.assertAllEqual(3, self.evaluate(produced))
+ self._ExpectRead(key, value, b"Y")
+ self._ExpectRead(key, value, b"Z")
+ self.assertAllEqual(3, self.evaluate(produced))
- queue.enqueue_many([["Y", "Z"]]).run()
- queue.close().run()
- reader.restore_state(state).run()
- self.assertAllEqual(1, self.evaluate(produced))
- self._ExpectRead(sess, key, value, b"Y")
- self._ExpectRead(sess, key, value, b"Z")
- with self.assertRaisesOpError("is closed and has insufficient elements "
- "\\(requested 1, current size 0\\)"):
- sess.run([key, value])
- self.assertAllEqual(3, self.evaluate(produced))
+ self.evaluate(queue.enqueue_many([["Y", "Z"]]))
+ self.evaluate(queue.close())
+ self.evaluate(reader.restore_state(state))
+ self.assertAllEqual(1, self.evaluate(produced))
+ self._ExpectRead(key, value, b"Y")
+ self._ExpectRead(key, value, b"Z")
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ self.evaluate([key, value])
+ self.assertAllEqual(3, self.evaluate(produced))
- self.assertEqual(bytes, type(state))
+ self.assertEqual(bytes, type(state))
- with self.assertRaises(ValueError):
- reader.restore_state([])
+ with self.assertRaises(ValueError):
+ reader.restore_state([])
- with self.assertRaises(ValueError):
- reader.restore_state([state, state])
+ with self.assertRaises(ValueError):
+ reader.restore_state([state, state])
- with self.assertRaisesOpError(
- "Could not parse state for IdentityReader 'test_reader'"):
- reader.restore_state(state[1:]).run()
+ with self.assertRaisesOpError(
+ "Could not parse state for IdentityReader 'test_reader'"):
+ self.evaluate(reader.restore_state(state[1:]))
- with self.assertRaisesOpError(
- "Could not parse state for IdentityReader 'test_reader'"):
- reader.restore_state(state[:-1]).run()
+ with self.assertRaisesOpError(
+ "Could not parse state for IdentityReader 'test_reader'"):
+ self.evaluate(reader.restore_state(state[:-1]))
- with self.assertRaisesOpError(
- "Could not parse state for IdentityReader 'test_reader'"):
- reader.restore_state(state + b"ExtraJunk").run()
+ with self.assertRaisesOpError(
+ "Could not parse state for IdentityReader 'test_reader'"):
+ self.evaluate(reader.restore_state(state + b"ExtraJunk"))
- with self.assertRaisesOpError(
- "Could not parse state for IdentityReader 'test_reader'"):
- reader.restore_state(b"PREFIX" + state).run()
+ with self.assertRaisesOpError(
+ "Could not parse state for IdentityReader 'test_reader'"):
+ self.evaluate(reader.restore_state(b"PREFIX" + state))
- with self.assertRaisesOpError(
- "Could not parse state for IdentityReader 'test_reader'"):
- reader.restore_state(b"BOGUS" + state[5:]).run()
+ with self.assertRaisesOpError(
+ "Could not parse state for IdentityReader 'test_reader'"):
+ self.evaluate(reader.restore_state(b"BOGUS" + state[5:]))
def testReset(self):
- with self.cached_session() as sess:
- reader = io_ops.IdentityReader("test_reader")
- work_completed = reader.num_work_units_completed()
- produced = reader.num_records_produced()
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- queued_length = queue.size()
- key, value = reader.read(queue)
+ reader = io_ops.IdentityReader("test_reader")
+ work_completed = reader.num_work_units_completed()
+ produced = reader.num_records_produced()
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ queued_length = queue.size()
+ key, value = reader.read(queue)
- queue.enqueue_many([["X", "Y", "Z"]]).run()
- self._ExpectRead(sess, key, value, b"X")
- self.assertLess(0, self.evaluate(queued_length))
- self.assertAllEqual(1, self.evaluate(produced))
+ self.evaluate(queue.enqueue_many([["X", "Y", "Z"]]))
+ self._ExpectRead(key, value, b"X")
+ self.assertLess(0, self.evaluate(queued_length))
+ self.assertAllEqual(1, self.evaluate(produced))
- self._ExpectRead(sess, key, value, b"Y")
- self.assertLess(0, self.evaluate(work_completed))
- self.assertAllEqual(2, self.evaluate(produced))
+ self._ExpectRead(key, value, b"Y")
+ self.assertLess(0, self.evaluate(work_completed))
+ self.assertAllEqual(2, self.evaluate(produced))
- reader.reset().run()
- self.assertAllEqual(0, self.evaluate(work_completed))
- self.assertAllEqual(0, self.evaluate(produced))
- self.assertAllEqual(1, self.evaluate(queued_length))
- self._ExpectRead(sess, key, value, b"Z")
+ self.evaluate(reader.reset())
+ self.assertAllEqual(0, self.evaluate(work_completed))
+ self.assertAllEqual(0, self.evaluate(produced))
+ self.assertAllEqual(1, self.evaluate(queued_length))
+ self._ExpectRead(key, value, b"Z")
- queue.enqueue_many([["K", "L"]]).run()
- self._ExpectRead(sess, key, value, b"K")
+ self.evaluate(queue.enqueue_many([["K", "L"]]))
+ self._ExpectRead(key, value, b"K")
class WholeFileReaderTest(test.TestCase):
@@ -301,44 +297,42 @@
os.remove(fn)
super(WholeFileReaderTest, self).tearDown()
- def _ExpectRead(self, sess, key, value, index):
- k, v = sess.run([key, value])
+ def _ExpectRead(self, key, value, index):
+ k, v = self.evaluate([key, value])
self.assertAllEqual(compat.as_bytes(self._filenames[index]), k)
self.assertAllEqual(self._content[index], v)
def testOneEpoch(self):
- with self.cached_session() as sess:
- reader = io_ops.WholeFileReader("test_reader")
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- queue.enqueue_many([self._filenames]).run()
- queue.close().run()
- key, value = reader.read(queue)
+ reader = io_ops.WholeFileReader("test_reader")
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ self.evaluate(queue.enqueue_many([self._filenames]))
+ self.evaluate(queue.close())
+ key, value = reader.read(queue)
- self._ExpectRead(sess, key, value, 0)
- self._ExpectRead(sess, key, value, 1)
- self._ExpectRead(sess, key, value, 2)
+ self._ExpectRead(key, value, 0)
+ self._ExpectRead(key, value, 1)
+ self._ExpectRead(key, value, 2)
- with self.assertRaisesOpError("is closed and has insufficient elements "
- "\\(requested 1, current size 0\\)"):
- sess.run([key, value])
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ self.evaluate([key, value])
def testInfiniteEpochs(self):
- with self.cached_session() as sess:
- reader = io_ops.WholeFileReader("test_reader")
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- enqueue = queue.enqueue_many([self._filenames])
- key, value = reader.read(queue)
+ reader = io_ops.WholeFileReader("test_reader")
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ enqueue = queue.enqueue_many([self._filenames])
+ key, value = reader.read(queue)
- enqueue.run()
- self._ExpectRead(sess, key, value, 0)
- self._ExpectRead(sess, key, value, 1)
- enqueue.run()
- self._ExpectRead(sess, key, value, 2)
- self._ExpectRead(sess, key, value, 0)
- self._ExpectRead(sess, key, value, 1)
- enqueue.run()
- self._ExpectRead(sess, key, value, 2)
- self._ExpectRead(sess, key, value, 0)
+ self.evaluate(enqueue)
+ self._ExpectRead(key, value, 0)
+ self._ExpectRead(key, value, 1)
+ self.evaluate(enqueue)
+ self._ExpectRead(key, value, 2)
+ self._ExpectRead(key, value, 0)
+ self._ExpectRead(key, value, 1)
+ self.evaluate(enqueue)
+ self._ExpectRead(key, value, 2)
+ self._ExpectRead(key, value, 0)
class TextLineReaderTest(test.TestCase):
@@ -366,22 +360,21 @@
return filenames
def _testOneEpoch(self, files):
- with self.cached_session() as sess:
- reader = io_ops.TextLineReader(name="test_reader")
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- key, value = reader.read(queue)
+ reader = io_ops.TextLineReader(name="test_reader")
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ key, value = reader.read(queue)
- queue.enqueue_many([files]).run()
- queue.close().run()
- for i in range(self._num_files):
- for j in range(self._num_lines):
- k, v = sess.run([key, value])
- self.assertAllEqual("%s:%d" % (files[i], j + 1), compat.as_text(k))
- self.assertAllEqual(self._LineText(i, j), v)
+ self.evaluate(queue.enqueue_many([files]))
+ self.evaluate(queue.close())
+ for i in range(self._num_files):
+ for j in range(self._num_lines):
+ k, v = self.evaluate([key, value])
+ self.assertAllEqual("%s:%d" % (files[i], j + 1), compat.as_text(k))
+ self.assertAllEqual(self._LineText(i, j), v)
- with self.assertRaisesOpError("is closed and has insufficient elements "
- "\\(requested 1, current size 0\\)"):
- k, v = sess.run([key, value])
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ k, v = self.evaluate([key, value])
def testOneEpochLF(self):
self._testOneEpoch(self._CreateFiles(crlf=False))
@@ -391,22 +384,21 @@
def testSkipHeaderLines(self):
files = self._CreateFiles()
- with self.cached_session() as sess:
- reader = io_ops.TextLineReader(skip_header_lines=1, name="test_reader")
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- key, value = reader.read(queue)
+ reader = io_ops.TextLineReader(skip_header_lines=1, name="test_reader")
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ key, value = reader.read(queue)
- queue.enqueue_many([files]).run()
- queue.close().run()
- for i in range(self._num_files):
- for j in range(self._num_lines - 1):
- k, v = sess.run([key, value])
- self.assertAllEqual("%s:%d" % (files[i], j + 2), compat.as_text(k))
- self.assertAllEqual(self._LineText(i, j + 1), v)
+ self.evaluate(queue.enqueue_many([files]))
+ self.evaluate(queue.close())
+ for i in range(self._num_files):
+ for j in range(self._num_lines - 1):
+ k, v = self.evaluate([key, value])
+ self.assertAllEqual("%s:%d" % (files[i], j + 2), compat.as_text(k))
+ self.assertAllEqual(self._LineText(i, j + 1), v)
- with self.assertRaisesOpError("is closed and has insufficient elements "
- "\\(requested 1, current size 0\\)"):
- k, v = sess.run([key, value])
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ k, v = self.evaluate([key, value])
class FixedLengthRecordReaderTest(TFCompressionTestCase):
@@ -522,55 +514,53 @@
# gap_bytes=hop_bytes-record_bytes
def _TestOneEpoch(self, files, num_records, gap_bytes, encoding=None):
hop_bytes = 0 if gap_bytes == 0 else self._record_bytes + gap_bytes
- with self.cached_session() as sess:
- reader = io_ops.FixedLengthRecordReader(
- header_bytes=self._header_bytes,
- record_bytes=self._record_bytes,
- footer_bytes=self._footer_bytes,
- hop_bytes=hop_bytes,
- encoding=encoding,
- name="test_reader")
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- key, value = reader.read(queue)
+ reader = io_ops.FixedLengthRecordReader(
+ header_bytes=self._header_bytes,
+ record_bytes=self._record_bytes,
+ footer_bytes=self._footer_bytes,
+ hop_bytes=hop_bytes,
+ encoding=encoding,
+ name="test_reader")
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ key, value = reader.read(queue)
- queue.enqueue_many([files]).run()
- queue.close().run()
- for i in range(self._num_files):
- for j in range(num_records):
- k, v = sess.run([key, value])
- self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
- self.assertAllEqual(self._Record(i, j), v)
+ self.evaluate(queue.enqueue_many([files]))
+ self.evaluate(queue.close())
+ for i in range(self._num_files):
+ for j in range(num_records):
+ k, v = self.evaluate([key, value])
+ self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
+ self.assertAllEqual(self._Record(i, j), v)
- with self.assertRaisesOpError("is closed and has insufficient elements "
- "\\(requested 1, current size 0\\)"):
- k, v = sess.run([key, value])
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ k, v = self.evaluate([key, value])
def _TestOneEpochWithHopBytes(self,
files,
num_overlapped_records,
encoding=None):
- with self.cached_session() as sess:
- reader = io_ops.FixedLengthRecordReader(
- header_bytes=self._header_bytes,
- record_bytes=self._record_bytes,
- footer_bytes=self._footer_bytes,
- hop_bytes=self._hop_bytes,
- encoding=encoding,
- name="test_reader")
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- key, value = reader.read(queue)
+ reader = io_ops.FixedLengthRecordReader(
+ header_bytes=self._header_bytes,
+ record_bytes=self._record_bytes,
+ footer_bytes=self._footer_bytes,
+ hop_bytes=self._hop_bytes,
+ encoding=encoding,
+ name="test_reader")
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ key, value = reader.read(queue)
- queue.enqueue_many([files]).run()
- queue.close().run()
- for i in range(self._num_files):
- for j in range(num_overlapped_records):
- k, v = sess.run([key, value])
- self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
- self.assertAllEqual(self._OverlappedRecord(i, j), v)
+ self.evaluate(queue.enqueue_many([files]))
+ self.evaluate(queue.close())
+ for i in range(self._num_files):
+ for j in range(num_overlapped_records):
+ k, v = self.evaluate([key, value])
+ self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
+ self.assertAllEqual(self._OverlappedRecord(i, j), v)
- with self.assertRaisesOpError("is closed and has insufficient elements "
- "\\(requested 1, current size 0\\)"):
- k, v = sess.run([key, value])
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ k, v = self.evaluate([key, value])
def testOneEpoch(self):
for num_records in [0, 7]:
@@ -621,84 +611,80 @@
def testOneEpoch(self):
files = self._CreateFiles()
- with self.cached_session() as sess:
- reader = io_ops.TFRecordReader(name="test_reader")
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- key, value = reader.read(queue)
+ reader = io_ops.TFRecordReader(name="test_reader")
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ key, value = reader.read(queue)
- queue.enqueue_many([files]).run()
- queue.close().run()
- for i in range(self._num_files):
- for j in range(self._num_records):
- k, v = sess.run([key, value])
- self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
- self.assertAllEqual(self._Record(i, j), v)
+ self.evaluate(queue.enqueue_many([files]))
+ self.evaluate(queue.close())
+ for i in range(self._num_files):
+ for j in range(self._num_records):
+ k, v = self.evaluate([key, value])
+ self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
+ self.assertAllEqual(self._Record(i, j), v)
- with self.assertRaisesOpError("is closed and has insufficient elements "
- "\\(requested 1, current size 0\\)"):
- k, v = sess.run([key, value])
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ k, v = self.evaluate([key, value])
def testReadUpTo(self):
files = self._CreateFiles()
- with self.cached_session() as sess:
- reader = io_ops.TFRecordReader(name="test_reader")
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- batch_size = 3
- key, value = reader.read_up_to(queue, batch_size)
+ reader = io_ops.TFRecordReader(name="test_reader")
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ batch_size = 3
+ key, value = reader.read_up_to(queue, batch_size)
- queue.enqueue_many([files]).run()
- queue.close().run()
- num_k = 0
- num_v = 0
+ self.evaluate(queue.enqueue_many([files]))
+ self.evaluate(queue.close())
+ num_k = 0
+ num_v = 0
- while True:
- try:
- k, v = sess.run([key, value])
- # Test reading *up to* batch_size records
- self.assertLessEqual(len(k), batch_size)
- self.assertLessEqual(len(v), batch_size)
- num_k += len(k)
- num_v += len(v)
- except errors_impl.OutOfRangeError:
- break
+ while True:
+ try:
+ k, v = self.evaluate([key, value])
+ # Test reading *up to* batch_size records
+ self.assertLessEqual(len(k), batch_size)
+ self.assertLessEqual(len(v), batch_size)
+ num_k += len(k)
+ num_v += len(v)
+ except errors_impl.OutOfRangeError:
+ break
- # Test that we have read everything
- self.assertEqual(self._num_files * self._num_records, num_k)
- self.assertEqual(self._num_files * self._num_records, num_v)
+ # Test that we have read everything
+ self.assertEqual(self._num_files * self._num_records, num_k)
+ self.assertEqual(self._num_files * self._num_records, num_v)
def testReadZlibFiles(self):
options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
files = self._CreateFiles(options)
- with self.cached_session() as sess:
- reader = io_ops.TFRecordReader(name="test_reader", options=options)
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- key, value = reader.read(queue)
+ reader = io_ops.TFRecordReader(name="test_reader", options=options)
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ key, value = reader.read(queue)
- queue.enqueue_many([files]).run()
- queue.close().run()
- for i in range(self._num_files):
- for j in range(self._num_records):
- k, v = sess.run([key, value])
- self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
- self.assertAllEqual(self._Record(i, j), v)
+ self.evaluate(queue.enqueue_many([files]))
+ self.evaluate(queue.close())
+ for i in range(self._num_files):
+ for j in range(self._num_records):
+ k, v = self.evaluate([key, value])
+ self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
+ self.assertAllEqual(self._Record(i, j), v)
def testReadGzipFiles(self):
options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
files = self._CreateFiles(options)
- with self.cached_session() as sess:
- reader = io_ops.TFRecordReader(name="test_reader", options=options)
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- key, value = reader.read(queue)
+ reader = io_ops.TFRecordReader(name="test_reader", options=options)
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ key, value = reader.read(queue)
- queue.enqueue_many([files]).run()
- queue.close().run()
- for i in range(self._num_files):
- for j in range(self._num_records):
- k, v = sess.run([key, value])
- self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
- self.assertAllEqual(self._Record(i, j), v)
+ self.evaluate(queue.enqueue_many([files]))
+ self.evaluate(queue.close())
+ for i in range(self._num_files):
+ for j in range(self._num_records):
+ k, v = self.evaluate([key, value])
+ self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
+ self.assertAllEqual(self._Record(i, j), v)
class AsyncReaderTest(test.TestCase):
@@ -733,7 +719,7 @@
fname = os.path.join(self.get_temp_dir(), "deadlock.%s.txt" % i)
with open(fname, "wb") as f:
f.write(("file-%s" % i).encode())
- d.queue.enqueue_many([[fname]]).run()
+ self.evaluate(d.queue.enqueue_many([[fname]]))
d.thread.join()
self.assertEqual([[("file-%s" % i).encode()]], d.output)
@@ -752,22 +738,21 @@
shutil.copy(path, self.db_path)
def testReadFromFile(self):
- with self.cached_session() as sess:
- reader = io_ops.LMDBReader(name="test_read_from_file")
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- key, value = reader.read(queue)
+ reader = io_ops.LMDBReader(name="test_read_from_file")
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ key, value = reader.read(queue)
- queue.enqueue([self.db_path]).run()
- queue.close().run()
- for i in range(10):
- k, v = sess.run([key, value])
- self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i)))
- self.assertAllEqual(
- compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i))))
+ self.evaluate(queue.enqueue([self.db_path]))
+ self.evaluate(queue.close())
+ for i in range(10):
+ k, v = self.evaluate([key, value])
+ self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i)))
+ self.assertAllEqual(
+ compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i))))
- with self.assertRaisesOpError("is closed and has insufficient elements "
- "\\(requested 1, current size 0\\)"):
- k, v = sess.run([key, value])
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ k, v = self.evaluate([key, value])
def testReadFromSameFile(self):
with self.cached_session() as sess:
@@ -782,29 +767,28 @@
threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
for _ in range(3):
for _ in range(10):
- k1, v1, k2, v2 = sess.run([key1, value1, key2, value2])
+ k1, v1, k2, v2 = self.evaluate([key1, value1, key2, value2])
self.assertAllEqual(compat.as_bytes(k1), compat.as_bytes(k2))
self.assertAllEqual(compat.as_bytes(v1), compat.as_bytes(v2))
coord.request_stop()
coord.join(threads)
def testReadFromFolder(self):
- with self.cached_session() as sess:
- reader = io_ops.LMDBReader(name="test_read_from_folder")
- queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
- key, value = reader.read(queue)
+ reader = io_ops.LMDBReader(name="test_read_from_folder")
+ queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+ key, value = reader.read(queue)
- queue.enqueue([self.db_path]).run()
- queue.close().run()
- for i in range(10):
- k, v = sess.run([key, value])
- self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i)))
- self.assertAllEqual(
- compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i))))
+ self.evaluate(queue.enqueue([self.db_path]))
+ self.evaluate(queue.close())
+ for i in range(10):
+ k, v = self.evaluate([key, value])
+ self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i)))
+ self.assertAllEqual(
+ compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i))))
- with self.assertRaisesOpError("is closed and has insufficient elements "
- "\\(requested 1, current size 0\\)"):
- k, v = sess.run([key, value])
+ with self.assertRaisesOpError("is closed and has insufficient elements "
+ "\\(requested 1, current size 0\\)"):
+ k, v = self.evaluate([key, value])
def testReadFromFileRepeatedly(self):
with self.cached_session() as sess:
@@ -819,7 +803,7 @@
for _ in range(3):
# Go over all 10 records each time.
for j in range(10):
- k, v = sess.run([key, value])
+ k, v = self.evaluate([key, value])
self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(j)))
self.assertAllEqual(
compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + j))))
diff --git a/tensorflow/python/kernel_tests/reduce_benchmark_test.py b/tensorflow/python/kernel_tests/reduce_benchmark_test.py
index 3a2fb81..ef9c4c3 100644
--- a/tensorflow/python/kernel_tests/reduce_benchmark_test.py
+++ b/tensorflow/python/kernel_tests/reduce_benchmark_test.py
@@ -81,7 +81,7 @@
grad, = gradients_impl.gradients(reduction, tensor)
def fn():
- sess.run(grad.op)
+ self.evaluate(grad.op)
self._run(fn, 10000)
@@ -98,7 +98,7 @@
grad, = gradients_impl.gradients(reduction, tensor)
def fn():
- sess.run(grad.op)
+ self.evaluate(grad.op)
self._run(fn, 10000)
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py
index 612b2c5..4eb3297 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test.py
@@ -238,7 +238,7 @@
with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_row_sum = self._tf_reduce(arr, 1, False)
tf_col_sum = self._tf_reduce(arr, 0, False)
- tf_out_row, tf_out_col = sess.run([tf_row_sum, tf_col_sum])
+ tf_out_row, tf_out_col = self.evaluate([tf_row_sum, tf_col_sum])
self.assertAllClose(col_sum, tf_out_col)
self.assertAllClose(row_sum, tf_out_row)
@@ -252,7 +252,7 @@
with self.session(graph=ops.Graph(), use_gpu=True) as sess:
tf_sum_xz = self._tf_reduce(arr, [0, 2], False)
tf_sum_y = self._tf_reduce(arr, 1, False)
- tf_out_sum_xz, tf_out_sum_y = sess.run([tf_sum_xz, tf_sum_y])
+ tf_out_sum_xz, tf_out_sum_y = self.evaluate([tf_sum_xz, tf_sum_y])
self.assertAllClose(sum_y, tf_out_sum_y)
self.assertAllClose(sum_xz, tf_out_sum_xz)
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index 68243f2..30cef90 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -147,7 +147,7 @@
# Repeat the experiment for 100 times. All tensor shapes and its tensor
# values are randomly generated for each run.
for _ in xrange(100):
- dx_f32_v, dx_f16_v = sess.run([dx_f32, dx_f16])
+ dx_f32_v, dx_f16_v = self.evaluate([dx_f32, dx_f16])
self.assertAllClose(dx_f32_v, dx_f16_v, atol=3e-4)
def testGradientFloat64(self):
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index c351a18..3056309 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -456,7 +456,7 @@
# TODO(alive): get this to work in Eager mode.
def testGPU(self):
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
abc = variable_scope.get_variable(
"abc",
shape=[1],
@@ -590,11 +590,11 @@
with ops.Graph().as_default(), self.cached_session() as sess:
# v describes a VariableDef-based variable without an initial value.
v = resource_variable_ops.ResourceVariable(variable_def=v_def)
- self.assertEqual(3.0, sess.run(v.initialized_value()))
+ self.assertEqual(3.0, self.evaluate(v.initialized_value()))
# initialized_value should not rerun the initializer_op if the variable
# has already been initialized elsewhere.
- sess.run(v.assign(1.0))
+ self.evaluate(v.assign(1.0))
self.assertEqual(1.0, v.initialized_value().eval())
v_def.ClearField("initial_value_name")
@@ -606,7 +606,7 @@
self.assertProtoEquals(v_def, v.to_proto())
# But attempts to use initialized_value will result in errors.
with self.assertRaises(ValueError):
- sess.run(v.initialized_value())
+ self.evaluate(v.initialized_value())
def testTrainableInProto(self):
with ops.Graph().as_default():
diff --git a/tensorflow/python/kernel_tests/save_restore_ops_test.py b/tensorflow/python/kernel_tests/save_restore_ops_test.py
index cb9aa1e..be117c4 100644
--- a/tensorflow/python/kernel_tests/save_restore_ops_test.py
+++ b/tensorflow/python/kernel_tests/save_restore_ops_test.py
@@ -17,14 +17,30 @@
from __future__ import division
from __future__ import print_function
+import os
+
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import test
+class SaveTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testRelativePath(self):
+ os.chdir(self.get_temp_dir())
+ self.evaluate(io_ops.save_v2(
+ "ckpt", ["x"], [""], [constant_op.constant(100.)]))
+ self.assertAllEqual([100.],
+ self.evaluate(io_ops.restore_v2(
+ "ckpt", ["x"], [""], [dtypes.float32])))
+
+
class ShardedFileOpsTest(test.TestCase):
def testShardedFileName(self):
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index 1f12497..c388121 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -161,7 +161,7 @@
init = variables.global_variables_initializer()
with self.session(use_gpu=True) as sess:
- sess.run(init)
+ self.evaluate(init)
result = self.evaluate(scatter)
self.assertAllClose(result, expected)
@@ -175,8 +175,8 @@
init = variables.global_variables_initializer()
with self.session(use_gpu=True) as sess:
- sess.run(init)
- sess.run(scatter)
+ self.evaluate(init)
+ self.evaluate(scatter)
self.assertAllClose(ref.eval(), expected)
def testSimple2(self):
@@ -189,7 +189,7 @@
init = variables.global_variables_initializer()
with self.session(use_gpu=True) as sess:
- sess.run(init)
+ self.evaluate(init)
result = self.evaluate(scatter)
self.assertAllClose(result, expected)
@@ -203,7 +203,7 @@
init = variables.global_variables_initializer()
with self.session(use_gpu=True) as sess:
- sess.run(init)
+ self.evaluate(init)
result = self.evaluate(scatter)
self.assertAllClose(result, expected)
@@ -341,7 +341,7 @@
init = variables.global_variables_initializer()
with session.Session() as sess:
- sess.run(init)
+ self.evaluate(init)
result = self.evaluate(scatter)
assert np.allclose(result, expected_result)
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
index 1c7006a..a4daad7 100644
--- a/tensorflow/python/kernel_tests/scatter_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -22,6 +22,7 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -320,19 +321,19 @@
updates = np.array([-3, -4, -5]).astype(np.float32)
# With GPU, the code ignores indices that are out of range.
# We don't test the implementation; just test there's no failures.
- with self.cached_session(force_gpu=True):
+ with test_util.force_gpu():
ref = variables.Variable(params)
ref.initializer.run()
# Indices all in range, no problem.
indices = np.array([2, 0, 5])
- op(ref, indices, updates).eval()
+ self.evaluate(op(ref, indices, updates))
# Indicies out of range should not fail.
indices = np.array([-1, 0, 5])
- op(ref, indices, updates).eval()
+ self.evaluate(op(ref, indices, updates))
indices = np.array([2, 0, 6])
- op(ref, indices, updates).eval()
+ self.evaluate(op(ref, indices, updates))
if __name__ == '__main__':
diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
index 8ca8e9d..42577f7 100644
--- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
+++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
@@ -81,7 +81,7 @@
self.assertEqual(matrix.shape, (32, 32))
matrix_tensor = constant_op.constant(matrix)
with self.session(use_gpu=True) as sess:
- (e, v) = sess.run(linalg_ops.self_adjoint_eig(matrix_tensor))
+ (e, v) = self.evaluate(linalg_ops.self_adjoint_eig(matrix_tensor))
self.assertEqual(e.size, 32)
self.assertAllClose(
np.matmul(v, v.transpose()), np.eye(32, dtype=np.float32), atol=2e-3)
diff --git a/tensorflow/python/kernel_tests/session_ops_test.py b/tensorflow/python/kernel_tests/session_ops_test.py
index 73d85dd..dc663cb 100644
--- a/tensorflow/python/kernel_tests/session_ops_test.py
+++ b/tensorflow/python/kernel_tests/session_ops_test.py
@@ -64,7 +64,7 @@
c = math_ops.multiply(a, b)
h = session_ops.get_session_handle(c)
v = math_ops.multiply(a, c)
- h, v = sess.run([h, v])
+ h, v = self.evaluate([h, v])
self.assertEqual(50, h.eval())
self.assertEqual(500, v)
@@ -77,7 +77,7 @@
p = math_ops.less(a, b)
c = math_ops.multiply(a, b)
h = session_ops.get_session_handle(c)
- p, h = sess.run([p, h])
+ p, h = self.evaluate([p, h])
# Run by feeding a tensor handle.
f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
@@ -154,7 +154,7 @@
b = constant_op.constant(5)
c = math_ops.multiply(a, b)
h = session_ops.get_session_handle(c)
- sess.run(h).delete()
+ self.evaluate(h).delete()
def testHandleDeleteRaw(self):
with self.cached_session() as sess:
@@ -174,10 +174,10 @@
with self.cached_session() as sess:
with ops.device(test.gpu_device_name()):
a = constant_op.constant(1.0)
- a_handle = sess.run(session_ops.get_session_handle(a))
+ a_handle = self.evaluate(session_ops.get_session_handle(a))
with ops.device("/cpu:0"):
b = constant_op.constant(2.0)
- b_handle = sess.run(session_ops.get_session_handle(b))
+ b_handle = self.evaluate(session_ops.get_session_handle(b))
a_p, a_t = session_ops.get_session_tensor(a_handle.handle, dtypes.float32)
b_p, b_t = session_ops.get_session_tensor(b_handle.handle, dtypes.float32)
@@ -193,8 +193,8 @@
# initial values live on CPU
with ops.device("/cpu:0"):
one = constant_op.constant(1, dtype=dtypes.float32)
- one_handle = sess.run(session_ops.get_session_handle(one))
- x_handle = sess.run(session_ops.get_session_handle(one))
+ one_handle = self.evaluate(session_ops.get_session_handle(one))
+ x_handle = self.evaluate(session_ops.get_session_handle(one))
# addition lives on GPU
with ops.device(test.gpu_device_name()):
@@ -239,7 +239,7 @@
c = math_ops.multiply(a, b)
d = math_ops.multiply(c, c)
- h_c = sess.run(session_ops.get_session_handle(c))
+ h_c = self.evaluate(session_ops.get_session_handle(c))
self.assertAllClose(2500.0, sess.run(d, feed_dict={c: h_c}))
@@ -248,7 +248,7 @@
a = constant_op.constant(10.0)
b = constant_op.constant(5.0)
c = math_ops.multiply(a, b)
- h_c = sess.run(session_ops.get_session_handle(c))
+ h_c = self.evaluate(session_ops.get_session_handle(c))
d = array_ops.identity(c)
c_val = sess.run(c, feed_dict={c: h_c})
@@ -277,8 +277,8 @@
d = math_ops.div(a, b)
e = math_ops.subtract(c, d)
- h_c = sess.run(session_ops.get_session_handle(c))
- h_d = sess.run(session_ops.get_session_handle(d))
+ h_c = self.evaluate(session_ops.get_session_handle(c))
+ h_d = self.evaluate(session_ops.get_session_handle(d))
self.assertAllClose(48.0, sess.run(e, feed_dict={c: h_c, d: h_d}))
self.assertAllClose(-48.0, sess.run(e, feed_dict={c: h_d, d: h_c}))
@@ -294,7 +294,7 @@
self.assertAllClose(12.0, self.evaluate(a))
self.assertAllClose(17.0, sess.run(b, feed_dict={a: h_a_read}))
- sess.run(inc_a)
+ self.evaluate(inc_a)
self.assertAllClose(19.0, sess.run(b, feed_dict={a: h_a_read}))
diff --git a/tensorflow/python/kernel_tests/sets_test.py b/tensorflow/python/kernel_tests/sets_test.py
index e037f51..ba3d32b 100644
--- a/tensorflow/python/kernel_tests/sets_test.py
+++ b/tensorflow/python/kernel_tests/sets_test.py
@@ -534,7 +534,7 @@
def _set_intersection_count(self, a, b):
op = sets.set_size(sets.set_intersection(a, b))
with self.cached_session() as sess:
- return sess.run(op)
+ return self.evaluate(op)
def test_set_difference_multirow_2d(self):
for dtype in _DTYPES:
@@ -972,7 +972,7 @@
def _set_difference_count(self, a, b, aminusb=True):
op = sets.set_size(sets.set_difference(a, b, aminusb))
with self.cached_session() as sess:
- return sess.run(op)
+ return self.evaluate(op)
def test_set_union_multirow_2d(self):
for dtype in _DTYPES:
@@ -1221,7 +1221,7 @@
def _set_union_count(self, a, b):
op = sets.set_size(sets.set_union(a, b))
with self.cached_session() as sess:
- return sess.run(op)
+ return self.evaluate(op)
def _assert_set_operation(self, expected_indices, expected_values,
expected_shape, sparse_tensor_value, dtype):
diff --git a/tensorflow/python/kernel_tests/signal/spectral_ops_test.py b/tensorflow/python/kernel_tests/signal/spectral_ops_test.py
index 7583c4d..7b9748c 100644
--- a/tensorflow/python/kernel_tests/signal/spectral_ops_test.py
+++ b/tensorflow/python/kernel_tests/signal/spectral_ops_test.py
@@ -235,7 +235,8 @@
inverse_window = inverse_window_fn(frame_length, dtype=dtypes.float32)
with self.cached_session(use_gpu=True) as sess:
- hann_window, inverse_window = sess.run([hann_window, inverse_window])
+ hann_window, inverse_window = self.evaluate(
+ [hann_window, inverse_window])
# Expect unit gain at each phase of the window.
product_window = hann_window * inverse_window
@@ -263,7 +264,8 @@
inverse_window = inverse_window_fn(frame_length, dtype=dtypes.float32)
with self.cached_session(use_gpu=True) as sess:
- hann_window, inverse_window = sess.run([hann_window, inverse_window])
+ hann_window, inverse_window = self.evaluate(
+ [hann_window, inverse_window])
self.assertAllClose(hann_window, inverse_window * 1.5)
@@ -293,7 +295,7 @@
# the sum of the magnitude STFT.
sinusoid = math_ops.sin(
2 * np.pi * math_ops.linspace(0.0, 1.0, signal_length))
- sinusoid_gradient = sess.run(self._compute_stft_gradient(sinusoid))
+ sinusoid_gradient = self.evaluate(self._compute_stft_gradient(sinusoid))
self.assertFalse((sinusoid_gradient == 0.0).all())
def test_gradients_numerical(self):
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index 5bb34a6..ee48c6e 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -207,7 +207,7 @@
dtype=dtypes.float32)
slice_t = array_ops.slice(a, [0, 0], [2, 2])
slice2_t = a[:2, :2]
- slice_val, slice2_val = sess.run([slice_t, slice2_t])
+ slice_val, slice2_val = self.evaluate([slice_t, slice2_t])
self.assertAllEqual(slice_val, inp[:2, :2])
self.assertAllEqual(slice2_val, inp[:2, :2])
self.assertEqual(slice_val.shape, slice_t.get_shape())
@@ -247,7 +247,7 @@
+ sizes[3], indices[4]:indices[4] + sizes[4], indices[5]:
indices[5] + sizes[5]]
- slice_val, slice2_val = sess.run([slice_t, slice2_t])
+ slice_val, slice2_val = self.evaluate([slice_t, slice2_t])
expected_val = inp[indices[0]:indices[0] + sizes[0], indices[1]:indices[
1] + sizes[1], indices[2]:indices[2] + sizes[2], indices[3]:indices[
@@ -313,7 +313,7 @@
g1 = gradients_impl.gradients(loss1, x)[0]
g2 = gradients_impl.gradients(loss2, x)[0]
- g1_val, g2_val = sess.run([g1, g2])
+ g1_val, g2_val = self.evaluate([g1, g2])
self.assertAllEqual(g1_val, g2_val)
def testGradientsAll(self):
diff --git a/tensorflow/python/kernel_tests/spacetodepth_op_test.py b/tensorflow/python/kernel_tests/spacetodepth_op_test.py
index c32b4ff..c9aaa68 100644
--- a/tensorflow/python/kernel_tests/spacetodepth_op_test.py
+++ b/tensorflow/python/kernel_tests/spacetodepth_op_test.py
@@ -36,21 +36,22 @@
def _testOne(self, inputs, block_size, outputs, dtype=dtypes.float32):
input_nhwc = math_ops.cast(inputs, dtype)
- with self.session(use_gpu=False):
+ with test_util.force_cpu():
# test NHWC (default) on CPU
x_tf = array_ops.space_to_depth(input_nhwc, block_size)
- self.assertAllEqual(x_tf.eval(), outputs)
- if test.is_gpu_available():
- with self.session(force_gpu=True):
+ self.assertAllEqual(self.evaluate(x_tf), outputs)
+
+ if test_util.is_gpu_available():
+ with test_util.force_gpu():
# test NHWC (default) on GPU
x_tf = array_ops.space_to_depth(input_nhwc, block_size)
- self.assertAllEqual(x_tf.eval(), outputs)
+ self.assertAllEqual(self.evaluate(x_tf), outputs)
# test NCHW on GPU
input_nchw = test_util.NHWCToNCHW(input_nhwc)
output_nchw = array_ops.space_to_depth(
input_nchw, block_size, data_format="NCHW")
output_nhwc = test_util.NCHWToNHWC(output_nchw)
- self.assertAllEqual(output_nhwc.eval(), outputs)
+ self.assertAllEqual(self.evaluate(output_nhwc), outputs)
def testBasic(self):
x_np = [[[[1], [2]], [[3], [4]]]]
@@ -134,13 +135,14 @@
input_nhwc = array_ops.ones([batch_size, 4, 6, 3])
x_out = array_ops.ones([batch_size, 2, 3, 12])
- with self.session(use_gpu=False):
+ with test_util.force_cpu():
# test NHWC (default) on CPU
x_tf = array_ops.space_to_depth(input_nhwc, block_size)
self.assertAllEqual(x_tf.shape, x_out.shape)
self.evaluate(x_tf)
+
if test.is_gpu_available():
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
# test NHWC (default) on GPU
x_tf = array_ops.space_to_depth(input_nhwc, block_size)
self.assertAllEqual(x_tf.shape, x_out.shape)
@@ -271,7 +273,7 @@
actual = array_ops.space_to_depth(t, block_size, data_format=data_format)
with self.cached_session(use_gpu=use_gpu) as sess:
- actual_vals, expected_vals = sess.run([actual, expected])
+ actual_vals, expected_vals = self.evaluate([actual, expected])
self.assertTrue(np.array_equal(actual_vals, expected_vals))
def testAgainstTranspose(self):
diff --git a/tensorflow/python/kernel_tests/sparse_add_op_test.py b/tensorflow/python/kernel_tests/sparse_add_op_test.py
index 845950b..c61f863 100644
--- a/tensorflow/python/kernel_tests/sparse_add_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_add_op_test.py
@@ -28,6 +28,7 @@
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
@@ -85,7 +86,7 @@
constant_op.constant(shape, dtypes.int64))
def testAddSelf(self):
- with self.session(use_gpu=False) as sess:
+ with test_util.force_cpu():
for sp_a in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
for sp_b in (self._SparseTensorValue_3x3(), self._SparseTensor_3x3()):
sp_sum = sparse_ops.sparse_add(sp_a, sp_b)
@@ -99,7 +100,7 @@
self.assertAllEqual(sum_out.dense_shape, [3, 3])
def testAddSelfAndNegation(self):
- with self.session(use_gpu=False) as sess:
+ with test_util.force_cpu():
sp_a = self._SparseTensor_3x3()
sp_b = self._SparseTensor_3x3(negate=True)
@@ -112,7 +113,7 @@
self.assertAllEqual(sum_out.dense_shape, [3, 3])
def testSmallValuesShouldVanish(self):
- with self.session(use_gpu=False) as sess:
+ with test_util.force_cpu():
sp_a = self._SparseTensor_3x3()
sp_b = self._SparseTensor_3x3_v2()
@@ -147,7 +148,7 @@
sp_a, nnz_a = self._randomTensor([n, m], np.float32)
sp_b, nnz_b = self._randomTensor([n, m], np.float32)
sp_sum = sparse_ops.sparse_add(sp_a, sp_b)
- nnz_sum = len(sp_sum.values.eval())
+ nnz_sum = len(self.evaluate(sp_sum.values))
err = gradient_checker.compute_gradient_error(
[sp_a.values, sp_b.values], [(nnz_a,), (nnz_b,)], sp_sum.values,
@@ -162,16 +163,16 @@
rand_vals_np = np.random.randn(n, m).astype(dtype)
dense_np = np.random.randn(n, m).astype(dtype)
- with self.cached_session(use_gpu=False):
+ with test_util.force_cpu():
sparse, unused_nnz = _sparsify(rand_vals_np, index_dtype=index_dtype)
- s = sparse_ops.sparse_add(sparse,
- constant_op.constant(dense_np)).eval()
+ s = self.evaluate(
+ sparse_ops.sparse_add(sparse, constant_op.constant(dense_np)))
self.assertAllEqual(dense_np + rand_vals_np, s)
self.assertTrue(s.dtype == dtype)
# check commutativity
- s = sparse_ops.sparse_add(constant_op.constant(dense_np),
- sparse).eval()
+ s = self.evaluate(
+ sparse_ops.sparse_add(constant_op.constant(dense_np), sparse))
self.assertAllEqual(dense_np + rand_vals_np, s)
self.assertTrue(s.dtype == dtype)
@@ -191,7 +192,7 @@
self.assertLess(err, 1e-3)
def testInvalidSparseTensor(self):
- with self.session(use_gpu=False) as sess:
+ with test_util.force_cpu():
shape = [2, 2]
val = [0]
dense = constant_op.constant(np.zeros(shape, dtype=np.int32))
@@ -205,7 +206,7 @@
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"invalid index"):
- sess.run(s)
+ self.evaluate(s)
######################## Benchmarking code
diff --git a/tensorflow/python/kernel_tests/sparse_concat_op_test.py b/tensorflow/python/kernel_tests/sparse_concat_op_test.py
index a3d136c..368a533 100644
--- a/tensorflow/python/kernel_tests/sparse_concat_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_concat_op_test.py
@@ -287,7 +287,7 @@
# Shape mismatches can only be caught when the op is run
with self.assertRaisesOpError("Input shapes must match"):
- sess.run(sp_concat)
+ self.evaluate(sp_concat)
def testMismatchedShapesExpandNonconcatDim(self):
with self.session(use_gpu=False) as sess:
diff --git a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
index 267275e..66589fa 100644
--- a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
+++ b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py
@@ -140,7 +140,7 @@
t = _indexedslice(mat_to_add)
q.apply_indexed_slices_grad(t).run()
- result = sess.run(q.take_indexed_slices_grad(1))
+ result = self.evaluate(q.take_indexed_slices_grad(1))
self._assertEqual_nparray(sum_elems / len(elems), result, sess)
@@ -381,7 +381,7 @@
self.evaluate(accum_op)
def take_grad():
- results.append(sess.run(takeg_t))
+ results.append(self.evaluate(takeg_t))
accum_thread = self.checkedThread(target=apply_indexed_slices_grad)
takeg_thread = self.checkedThread(target=take_grad)
@@ -585,7 +585,7 @@
np.float32)).run()
# After take grad, constraints on accumulated gradient are removed
- sess.run(q.take_grad(1))
+ self.evaluate(q.take_grad(1))
# First successful gradient imposes new constraints.
# Hereafter, shape will additionally constrained to [None,2,2,3]
@@ -615,7 +615,7 @@
grad_values=np.array(
[[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]).astype(np.float32)).run()
- val = sess.run(q.take_indexed_slices_grad(1))
+ val = self.evaluate(q.take_indexed_slices_grad(1))
self.assertAllEqual(val.dense_shape, [2, 2, 2, 2])
q = data_flow_ops.SparseConditionalAccumulator(
@@ -627,7 +627,7 @@
[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]]).astype(
np.float32)).run()
- val = sess.run(q.take_indexed_slices_grad(1))
+ val = self.evaluate(q.take_indexed_slices_grad(1))
self.assertAllEqual(val.dense_shape, [-1, 2, 2, 3])
def testApplyGradtInt32IndicesAndShape(self):
@@ -653,7 +653,7 @@
accum_op.run()
self.assertEqual(q.num_accumulated().eval(), 2)
- val = sess.run(q.take_indexed_slices_grad(1))
+ val = self.evaluate(q.take_indexed_slices_grad(1))
self.assertAllEqual(val.indices, [0, 2])
self.assertAllEqual(val.values, [[0, 0, 1], [3, 0, 4]])
self.assertAllEqual(val.dense_shape, [3, 3])
diff --git a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
index e63ba8f..538e7c6 100644
--- a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
@@ -166,7 +166,7 @@
with self.assertRaisesOpError(
r"Inconsistent rank across SparseTensors: rank prior to "
r"SparseTensor\[1\] was: 3 but rank of SparseTensor\[1\] is: 4"):
- sess.run(sp_roundtrip)
+ self.evaluate(sp_roundtrip)
def testTakeManyFailsWrongInputOp(self):
with self.session(use_gpu=False) as sess:
@@ -178,7 +178,7 @@
sparse_map_op=handle.op, sparse_handles=[handle_value, bad_handle])
with self.assertRaisesOpError(r"Unable to find SparseTensor: 10"):
- sess.run(sp_roundtrip)
+ self.evaluate(sp_roundtrip)
class BenchmarkSparseTensorsMapVsSerialization(test.Benchmark):
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index 3f91131..cc8c7c2 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -66,7 +66,7 @@
with self.cached_session(use_gpu=True) as sess:
loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
np_features, np_labels)
- tf_loss, tf_backprop = sess.run([loss, backprop])
+ tf_loss, tf_backprop = self.evaluate([loss, backprop])
self.assertAllCloseAccordingToType(np_loss, tf_loss)
self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
@@ -76,7 +76,7 @@
loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
np.array([[1.], [-1.], [0.]]).astype(np.float32),
np.array([0, 0, 0]).astype(label_dtype))
- tf_loss, tf_backprop = sess.run([loss, backprop])
+ tf_loss, tf_backprop = self.evaluate([loss, backprop])
self.assertAllClose([0.0, 0.0, 0.0], tf_loss)
self.assertAllClose([[0.0], [0.0], [0.0]], tf_backprop)
@@ -90,7 +90,7 @@
loss, backprop = (
gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
features, labels))
- tf_loss, tf_backprop = sess.run([loss, backprop])
+ tf_loss, tf_backprop = self.evaluate([loss, backprop])
self.assertAllClose(
[[np.nan] * 4, [0.25, 0.25, 0.25, -0.75],
[-0.968, 0.087, 0.237, 0.6439], [np.nan] * 4],
@@ -104,7 +104,7 @@
loss, backprop = (
gen_nn_ops.sparse_softmax_cross_entropy_with_logits(features, labels))
with self.assertRaisesOpError("Received a label value of"):
- sess.run([loss, backprop])
+ self.evaluate([loss, backprop])
def testNpXent(self):
# We create 2 batches of logits for testing.
@@ -226,7 +226,7 @@
loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=features)
backprop = loss.op.inputs[0].op.outputs[1]
- tf_loss, tf_backprop = sess.run([loss, backprop])
+ tf_loss, tf_backprop = self.evaluate([loss, backprop])
self.assertAllCloseAccordingToType(np_loss, tf_loss)
self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
diff --git a/tensorflow/python/kernel_tests/stack_ops_test.py b/tensorflow/python/kernel_tests/stack_ops_test.py
index 6c6fe8a..dffb260 100644
--- a/tensorflow/python/kernel_tests/stack_ops_test.py
+++ b/tensorflow/python/kernel_tests/stack_ops_test.py
@@ -131,7 +131,7 @@
pop1 = gen_data_flow_ops.stack_pop_v2(h1, dtypes.float32)
pop2 = gen_data_flow_ops.stack_pop_v2(h2, dtypes.float32)
- out1, out2 = sess.run([pop1, pop2])
+ out1, out2 = self.evaluate([pop1, pop2])
self.assertAllClose(out1, 4.0)
self.assertAllClose(out2, 5.0)
@@ -144,7 +144,7 @@
h = gen_data_flow_ops.stack_v2(
-1, elem_type=dtypes.float32, stack_name="foo")
c1 = gen_data_flow_ops.stack_close_v2(h)
- sess.run(c1)
+ self.evaluate(c1)
def testCloseStack(self):
self._testCloseStack(use_gpu=False)
@@ -157,7 +157,7 @@
c = gen_data_flow_ops.stack_push_v2(h, [[4.0, 5.0]])
with ops.control_dependencies([c]):
c1 = gen_data_flow_ops.stack_close_v2(h)
- sess.run(c1)
+ self.evaluate(c1)
def testPushCloseStack(self):
self._testPushCloseStack(use_gpu=False)
@@ -263,7 +263,7 @@
with self.cached_session(use_gpu=use_gpu) as sess:
h = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo")
c1 = gen_data_flow_ops.stack_close(h)
- sess.run(c1)
+ self.evaluate(c1)
def testCloseStack(self):
self._testCloseStack(use_gpu=False)
@@ -275,7 +275,7 @@
c = gen_data_flow_ops.stack_push(h, [[4.0, 5.0]])
with ops.control_dependencies([c]):
c1 = gen_data_flow_ops.stack_close(h)
- sess.run(c1)
+ self.evaluate(c1)
def testPushCloseStack(self):
self._testPushCloseStack(use_gpu=False)
diff --git a/tensorflow/python/kernel_tests/stage_op_test.py b/tensorflow/python/kernel_tests/stage_op_test.py
index b1e7ce5..b814843 100644
--- a/tensorflow/python/kernel_tests/stage_op_test.py
+++ b/tensorflow/python/kernel_tests/stage_op_test.py
@@ -152,11 +152,11 @@
with self.session(use_gpu=True, graph=G) as sess:
sess.run(stage, feed_dict={x: -1})
- self.assertEqual(self.evaluate(size), 1)
+ self.assertEqual(sess.run(size), 1)
sess.run(stage, feed_dict={x: -1})
- self.assertEqual(self.evaluate(size), 2)
+ self.assertEqual(sess.run(size), 2)
sess.run(clear)
- self.assertEqual(self.evaluate(size), 0)
+ self.assertEqual(sess.run(size), 0)
def testCapacity(self):
capacity = 3
@@ -210,14 +210,14 @@
capacity))
# Should have capacity elements in the staging area
- self.assertTrue(self.evaluate(size) == capacity)
+ self.assertTrue(sess.run(size) == capacity)
# Clear the staging area completely
for i in range(n):
- self.assertTrue(self.evaluate(ret) == [i])
+ self.assertTrue(sess.run(ret) == [i])
# It should now be empty
- self.assertTrue(self.evaluate(size) == 0)
+ self.assertTrue(sess.run(size) == 0)
def testMemoryLimit(self):
memory_limit = 512 * 1024 # 512K
@@ -274,13 +274,13 @@
capacity))
# Should have capacity elements in the staging area
- self.assertTrue(self.evaluate(size) == capacity)
+ self.assertTrue(sess.run(size) == capacity)
# Clear the staging area completely
for i in range(n):
- self.assertTrue(np.all(self.evaluate(ret)[0] == i))
+ self.assertTrue(np.all(sess.run(ret)[0] == i))
- self.assertTrue(self.evaluate(size) == 0)
+ self.assertTrue(sess.run(size) == 0)
if __name__ == '__main__':
diff --git a/tensorflow/python/kernel_tests/string_length_op_test.py b/tensorflow/python/kernel_tests/string_length_op_test.py
index 0c68f0c..06bf28e 100644
--- a/tensorflow/python/kernel_tests/string_length_op_test.py
+++ b/tensorflow/python/kernel_tests/string_length_op_test.py
@@ -43,9 +43,9 @@
utf8_char_lengths = string_ops.string_length(
utf8_strings, unit="UTF8_CHAR")
self.assertAllEqual(
- sess.run(utf8_byte_lengths), expected_utf8_byte_lengths)
+ self.evaluate(utf8_byte_lengths), expected_utf8_byte_lengths)
self.assertAllEqual(
- sess.run(utf8_char_lengths), expected_utf8_char_lengths)
+ self.evaluate(utf8_char_lengths), expected_utf8_char_lengths)
with self.assertRaisesRegexp(
ValueError, "Attr 'unit' of 'StringLength' Op passed string 'XYZ' "
'not in: "BYTE", "UTF8_CHAR"'):
diff --git a/tensorflow/python/kernel_tests/summary_v1_tensor_op_test.py b/tensorflow/python/kernel_tests/summary_v1_tensor_op_test.py
index 71251f5..b8e5b5b 100644
--- a/tensorflow/python/kernel_tests/summary_v1_tensor_op_test.py
+++ b/tensorflow/python/kernel_tests/summary_v1_tensor_op_test.py
@@ -50,7 +50,7 @@
with ops.name_scope("zod"):
s3 = summary_lib.tensor_summary("s3", c)
s4 = summary_lib.tensor_summary("TensorSummary", c)
- summ1, summ2, summ3, summ4 = sess.run([s1, s2, s3, s4])
+ summ1, summ2, summ3, summ4 = self.evaluate([s1, s2, s3, s4])
v1 = self._SummarySingleValue(summ1)
self.assertEqual(v1.tag, "s1")
diff --git a/tensorflow/python/kernel_tests/svd_op_test.py b/tensorflow/python/kernel_tests/svd_op_test.py
index 589172e..97a280e 100644
--- a/tensorflow/python/kernel_tests/svd_op_test.py
+++ b/tensorflow/python/kernel_tests/svd_op_test.py
@@ -150,7 +150,7 @@
s_tf, u_tf, v_tf = linalg_ops.svd(
x_tf, compute_uv=compute_uv_, full_matrices=full_matrices_)
if use_static_shape_:
- s_tf_val, u_tf_val, v_tf_val = sess.run([s_tf, u_tf, v_tf])
+ s_tf_val, u_tf_val, v_tf_val = self.evaluate([s_tf, u_tf, v_tf])
else:
s_tf_val, u_tf_val, v_tf_val = sess.run(
[s_tf, u_tf, v_tf], feed_dict={x_tf: x_np})
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 4ee1c27..bb8645e 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -1583,7 +1583,7 @@
# wrap it in the correct name scope.
dx, = gradients_impl.gradients(ys=[y], xs=[x], grad_ys=[dy])
with self.cached_session(use_gpu=True) as sess:
- vdx, vdy = sess.run([dx, dy])
+ vdx, vdy = self.evaluate([dx, dy])
self.assertAllClose(vdx, vdy)
def testSkipEagerTensorArrayInt64GPU(self):
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py
index d9f340d..a72888c 100644
--- a/tensorflow/python/kernel_tests/topk_op_test.py
+++ b/tensorflow/python/kernel_tests/topk_op_test.py
@@ -48,7 +48,7 @@
np_expected_indices = np.array(expected_indices)
with self.cached_session(use_gpu=True) as sess:
values_op, indices_op = nn_ops.top_k(inputs, k, sorted=sorted)
- values, indices = sess.run([values_op, indices_op])
+ values, indices = self.evaluate([values_op, indices_op])
self.assertShapeEqual(np_expected_values, values_op)
self.assertShapeEqual(np_expected_indices, indices_op)
diff --git a/tensorflow/python/kernel_tests/unicode_transcode_op_test.py b/tensorflow/python/kernel_tests/unicode_transcode_op_test.py
index d1c7b41..037ecd1 100644
--- a/tensorflow/python/kernel_tests/unicode_transcode_op_test.py
+++ b/tensorflow/python/kernel_tests/unicode_transcode_op_test.py
@@ -143,7 +143,7 @@
errors="strict")
with self.assertRaisesOpError(
"Invalid formatting on input string"):
- sess.run(outputs)
+ self.evaluate(outputs)
def test_transcode_bad_utf8_start_with_strict_errors(self):
bad_string = b"\xffabcd"
@@ -155,7 +155,7 @@
errors="strict")
with self.assertRaisesOpError(
"Invalid formatting on input string"):
- sess.run(outputs)
+ self.evaluate(outputs)
def test_transcode_bad_utf8_with_elision_of_malformatting(self):
bad_string = b"\x00\xff"
@@ -336,7 +336,7 @@
replace_control_characters=False)
with self.assertRaisesOpError(
"Could not create converter for input encoding: invalid"):
- sess.run(outputs)
+ self.evaluate(outputs)
with self.assertRaisesRegexp(ValueError, "Op passed string 'invalid'"):
with self.cached_session() as sess:
@@ -347,7 +347,7 @@
errors="replace",
replacement_char=ord(" "),
replace_control_characters=False)
- sess.run(outputs)
+ self.evaluate(outputs)
def test_invalid_error_policy_causes_errors(self):
strings = [[b"a", b"abc"], [b"ABC", b"DEF"]]
@@ -362,7 +362,7 @@
errors="invalid",
replacement_char=ord(" "),
replace_control_characters=False)
- sess.run(outputs)
+ self.evaluate(outputs)
def test_forwarding(self):
with self.cached_session():
diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py
index 316570e..f203263 100644
--- a/tensorflow/python/kernel_tests/unique_op_test.py
+++ b/tensorflow/python/kernel_tests/unique_op_test.py
@@ -32,7 +32,7 @@
x = np.random.randint(2, high=10, size=7000)
with self.cached_session() as sess:
y, idx = array_ops.unique(x)
- tf_y, tf_idx = sess.run([y, idx])
+ tf_y, tf_idx = self.evaluate([y, idx])
self.assertEqual(len(x), len(tf_idx))
self.assertEqual(len(tf_y), len(np.unique(x)))
@@ -43,7 +43,7 @@
x = np.random.randint(2, high=10, size=7000)
with self.cached_session() as sess:
y, idx = array_ops.unique(x, out_idx=dtypes.int64)
- tf_y, tf_idx = sess.run([y, idx])
+ tf_y, tf_idx = self.evaluate([y, idx])
self.assertEqual(len(x), len(tf_idx))
self.assertEqual(len(tf_y), len(np.unique(x)))
@@ -55,7 +55,7 @@
x = [chr(i) for i in indx]
with self.cached_session() as sess:
y, idx = array_ops.unique(x)
- tf_y, tf_idx = sess.run([y, idx])
+ tf_y, tf_idx = self.evaluate([y, idx])
self.assertEqual(len(x), len(tf_idx))
self.assertEqual(len(tf_y), len(np.unique(x)))
@@ -67,9 +67,9 @@
x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
with self.cached_session() as sess:
y0, idx0 = gen_array_ops.unique_v2(x, axis=np.array([0], dtype))
- tf_y0, tf_idx0 = sess.run([y0, idx0])
+ tf_y0, tf_idx0 = self.evaluate([y0, idx0])
y1, idx1 = gen_array_ops.unique_v2(x, axis=np.array([1], dtype))
- tf_y1, tf_idx1 = sess.run([y1, idx1])
+ tf_y1, tf_idx1 = self.evaluate([y1, idx1])
self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]]))
self.assertAllEqual(tf_idx0, np.array([0, 0, 1]))
self.assertAllEqual(tf_y1, np.array([[1, 0], [1, 0], [2, 0]]))
@@ -81,7 +81,7 @@
x = np.random.randint(2, high=10, size=7000)
with self.cached_session() as sess:
y, idx = gen_array_ops.unique_v2(x, axis=np.array([], np.int32))
- tf_y, tf_idx = sess.run([y, idx])
+ tf_y, tf_idx = self.evaluate([y, idx])
self.assertEqual(len(x), len(tf_idx))
self.assertEqual(len(tf_y), len(np.unique(x)))
@@ -95,7 +95,7 @@
x = np.random.randint(2, high=10, size=7000)
with self.cached_session() as sess:
y, idx, count = array_ops.unique_with_counts(x)
- tf_y, tf_idx, tf_count = sess.run([y, idx, count])
+ tf_y, tf_idx, tf_count = self.evaluate([y, idx, count])
self.assertEqual(len(x), len(tf_idx))
self.assertEqual(len(tf_y), len(np.unique(x)))
@@ -108,7 +108,7 @@
x = np.random.randint(2, high=10, size=7000)
with self.cached_session() as sess:
y, idx, count = array_ops.unique_with_counts(x, out_idx=dtypes.int64)
- tf_y, tf_idx, tf_count = sess.run([y, idx, count])
+ tf_y, tf_idx, tf_count = self.evaluate([y, idx, count])
self.assertEqual(len(x), len(tf_idx))
self.assertEqual(len(tf_y), len(np.unique(x)))
@@ -123,7 +123,7 @@
with self.cached_session() as sess:
y, idx, count = array_ops.unique_with_counts(x)
- tf_y, tf_idx, tf_count = sess.run([y, idx, count])
+ tf_y, tf_idx, tf_count = self.evaluate([y, idx, count])
self.assertEqual(len(x), len(tf_idx))
self.assertEqual(len(tf_y), len(np.unique(x)))
@@ -139,10 +139,10 @@
with self.cached_session() as sess:
y0, idx0, count0 = gen_array_ops.unique_with_counts_v2(
x, axis=np.array([0], dtype))
- tf_y0, tf_idx0, tf_count0 = sess.run([y0, idx0, count0])
+ tf_y0, tf_idx0, tf_count0 = self.evaluate([y0, idx0, count0])
y1, idx1, count1 = gen_array_ops.unique_with_counts_v2(
x, axis=np.array([1], dtype))
- tf_y1, tf_idx1, tf_count1 = sess.run([y1, idx1, count1])
+ tf_y1, tf_idx1, tf_count1 = self.evaluate([y1, idx1, count1])
self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]]))
self.assertAllEqual(tf_idx0, np.array([0, 0, 1]))
self.assertAllEqual(tf_count0, np.array([2, 1]))
@@ -157,7 +157,7 @@
with self.cached_session() as sess:
y, idx, count = gen_array_ops.unique_with_counts_v2(
x, axis=np.array([], np.int32))
- tf_y, tf_idx, tf_count = sess.run([y, idx, count])
+ tf_y, tf_idx, tf_count = self.evaluate([y, idx, count])
self.assertEqual(len(x), len(tf_idx))
self.assertEqual(len(tf_y), len(np.unique(x)))
diff --git a/tensorflow/python/kernel_tests/unstack_op_test.py b/tensorflow/python/kernel_tests/unstack_op_test.py
index 6aea429..d314e1e 100644
--- a/tensorflow/python/kernel_tests/unstack_op_test.py
+++ b/tensorflow/python/kernel_tests/unstack_op_test.py
@@ -41,7 +41,7 @@
def testSimple(self):
np.random.seed(7)
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
for dtype in [
np.bool, np.float16, np.float32, np.float64, np.int32, np.int64
@@ -53,14 +53,15 @@
cs = array_ops.unstack(x, num=shape[0])
self.assertEqual(type(cs), list)
self.assertEqual(len(cs), shape[0])
- cs = [c.eval() for c in cs]
+ cs = [self.evaluate(c) for c in cs]
self.assertAllEqual(cs, data)
def testSimpleGpu(self):
if not test_util.is_gpu_available():
self.skipTest('No GPU available')
+
np.random.seed(7)
- with self.session(use_gpu=True, force_gpu=True):
+ with test_util.force_gpu():
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
for dtype in [np.float16, np.float32, np.float64, np.int32, np.int64]:
data = np.random.randn(*shape).astype(dtype)
@@ -70,7 +71,7 @@
cs = array_ops.unstack(x, num=shape[0])
self.assertEqual(type(cs), list)
self.assertEqual(len(cs), shape[0])
- cs = [c.eval() for c in cs]
+ cs = [self.evaluate(c) for c in cs]
self.assertAllEqual(cs, data)
def testGradientsAxis0(self):
@@ -131,15 +132,13 @@
for j in range(-i, i):
expected = np_split_squeeze(a, j)
- with self.cached_session() as sess:
- actual_unstack = sess.run(array_ops.unstack(a, axis=j))
+ actual_unstack = self.evaluate(array_ops.unstack(a, axis=j))
self.assertAllEqual(expected, actual_unstack)
def testAxis0Default(self):
- with self.cached_session() as sess:
- a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
- unstacked = sess.run(array_ops.unstack(a))
+ a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
+ unstacked = self.evaluate(array_ops.unstack(a))
self.assertEqual(len(unstacked), 2)
self.assertAllEqual(unstacked[0], [1, 2, 3])
@@ -156,10 +155,9 @@
array_ops.unstack(a, axis=-3)
def testZeroLengthDim(self):
- with self.cached_session():
- x = array_ops.zeros(shape=(0, 1, 2))
- y = array_ops.unstack(x, axis=1)[0].eval()
- self.assertEqual(y.shape, (0, 2))
+ x = array_ops.zeros(shape=(0, 1, 2))
+ y = self.evaluate(array_ops.unstack(x, axis=1)[0])
+ self.assertEqual(y.shape, (0, 2))
if __name__ == '__main__':
diff --git a/tensorflow/python/kernel_tests/variable_ops_test.py b/tensorflow/python/kernel_tests/variable_ops_test.py
index 769bbba..c63d7f8 100644
--- a/tensorflow/python/kernel_tests/variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/variable_ops_test.py
@@ -24,6 +24,7 @@
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
@@ -164,7 +165,7 @@
self.assertEqual(tensor_shape.unknown_shape(), subbed.get_shape())
def testTemporaryVariable(self):
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
var = gen_state_ops.temporary_variable(
[1, 2], dtypes.float32, var_name="foo")
var = state_ops.assign(var, [[4.0, 5.0]])
@@ -173,14 +174,14 @@
self.assertAllClose([[10.0, 12.0]], self.evaluate(final))
def testDestroyNonexistentTemporaryVariable(self):
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
var = gen_state_ops.temporary_variable([1, 2], dtypes.float32)
final = gen_state_ops.destroy_temporary_variable(var, var_name="bad")
with self.assertRaises(errors.NotFoundError):
self.evaluate(final)
def testDuplicateTemporaryVariable(self):
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
var1 = gen_state_ops.temporary_variable(
[1, 2], dtypes.float32, var_name="dup")
var1 = state_ops.assign(var1, [[1.0, 2.0]])
@@ -192,7 +193,7 @@
self.evaluate(final)
def testDestroyTemporaryVariableTwice(self):
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
var = gen_state_ops.temporary_variable([1, 2], dtypes.float32)
val1 = gen_state_ops.destroy_temporary_variable(var, var_name="dup")
val2 = gen_state_ops.destroy_temporary_variable(var, var_name="dup")
@@ -201,14 +202,14 @@
self.evaluate(final)
def testTemporaryVariableNoLeak(self):
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
var = gen_state_ops.temporary_variable(
[1, 2], dtypes.float32, var_name="bar")
final = array_ops.identity(var)
self.evaluate(final)
def testTwoTemporaryVariablesNoLeaks(self):
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
var1 = gen_state_ops.temporary_variable(
[1, 2], dtypes.float32, var_name="var1")
var2 = gen_state_ops.temporary_variable(
@@ -217,13 +218,13 @@
self.evaluate(final)
def testAssignDependencyAcrossDevices(self):
- with self.test_session(use_gpu=True):
+ with test_util.use_gpu():
# The variable and an op to increment it are on the GPU.
var = state_ops.variable_op([1], dtypes.float32)
- state_ops.assign(var, [1.0]).eval()
+ self.evaluate(state_ops.assign(var, [1.0]))
increment = state_ops.assign_add(var, [1.0])
with ops.control_dependencies([increment]):
- with ops.device("/cpu:0"):
+ with test_util.force_cpu():
# This mul op is pinned to the CPU, but reads the variable from the
# GPU. The test ensures that the dependency on 'increment' is still
# honored, i.e., the Send and Recv from GPU to CPU should take place
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 838838e..3720f73 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -308,7 +308,6 @@
self.evaluate(variables_lib.global_variables_initializer())
self.assertAllEqual(self.evaluate(x.value()), self.evaluate(y.value()))
- # TODO(alive): support variable partitioning/caching in eager mode.
# TODO(mihaimaruseac): Not converted to use wrap_function because of
# InvalidArgumentError: /job:moo/replica:0/task:0/device:CPU:0 unknown device.
def testVarScopeCachingDevice(self):
@@ -435,19 +434,19 @@
add = v1 + v0
# v0 should be uninitialized.
with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
- sess.run(v0)
+ self.evaluate(v0)
# We should be able to initialize and run v1 without initializing
# v0, even if the variable was created with a control dep on v0.
self.evaluate(v1.initializer)
self.assertEqual(1, self.evaluate(v1))
# v0 should still be uninitialized.
with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
- sess.run(v0)
+ self.evaluate(v0)
with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
- sess.run(add)
+ self.evaluate(add)
# If we initialize v0 we should be able to run 'add'.
self.evaluate(v0.initializer)
- sess.run(add)
+ self.evaluate(add)
# TODO(mihaimaruseac): Not converted to use wrap_function because of
# AssertionError: True is not false (last assertFalse)
@@ -496,13 +495,13 @@
self.assertEqual([2], self.evaluate(v2))
# v0 should still be uninitialized.
with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
- sess.run(v0)
+ self.evaluate(v0)
# We should not be able to run 'add' yet.
with self.assertRaisesRegexp(errors.OpError, "uninitialized"):
- sess.run(add)
+ self.evaluate(add)
# If we initialize v0 we should be able to run 'add'.
self.evaluate(v0.initializer)
- sess.run(add)
+ self.evaluate(add)
# TODO(mihaimaruseac): Not converted to use wrap_function because of
# TypeError: Expected tf.group() expected Tensor arguments not 'None' with
@@ -1315,6 +1314,28 @@
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
+ def testGetVariableWithInitializerWhichTakesNoArgs(self):
+ v = variable_scope.get_variable("foo", initializer=lambda: [2])
+ self.assertEqual(v.name, "foo:0")
+
+ @test_util.run_in_graph_and_eager_modes
+ @run_inside_wrap_function_in_eager_mode
+ def testGetVariableWithInitializerWhichTakesOptionalArgs(self):
+ v = variable_scope.get_variable("foo", initializer=lambda x=True: [2])
+ self.assertEqual(v.name, "foo:0")
+
+ @test_util.run_in_graph_and_eager_modes
+ @run_inside_wrap_function_in_eager_mode
+ def testGetVariableWithInitializerWhichTakesUnprovidedArgsAndNoShape(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ "The initializer passed is not valid. It should be a callable with no "
+ "arguments and the shape should not be provided or an instance of "
+ "`tf.keras.initializers.*' and `shape` should be fully defined."):
+ variable_scope.get_variable("foo", initializer=lambda x: [2])
+
+ @test_util.run_in_graph_and_eager_modes
+ @run_inside_wrap_function_in_eager_mode
def testTwoGraphs(self):
def f():
@@ -1404,6 +1425,14 @@
v_reused = variable_scope.get_variable("name0")
self.assertEqual(v, v_reused)
+ def testNoReuseInEagerByDefault(self):
+ with context.eager_mode():
+ with variable_scope.variable_scope(
+ "scope0", partitioner=axis0_into2_partitioner):
+ v1 = variable_scope.get_variable("name0", shape=(3, 1, 1))
+ v2 = variable_scope.get_variable("name0", shape=(3, 1, 1))
+ self.assertIsNot(v1, v2)
+
@test_util.run_in_graph_and_eager_modes
@run_inside_wrap_function_in_eager_mode
def testPropagatePartitionerOnReopening(self):
@@ -1459,6 +1488,10 @@
def testPartitionConcatenatesAlongCorrectAxisResource(self):
self._testPartitionConcatenatesAlongCorrectAxis(use_resource=True)
+ def testPartitionConcatenatesAlongCorrectAxisResourceInEager(self):
+ with context.eager_mode():
+ self._testPartitionConcatenatesAlongCorrectAxis(use_resource=True)
+
class VariableScopeWithCustomGetterTest(test.TestCase):
@@ -1569,7 +1602,7 @@
self.assertEqual("custom_getter/add:0", v.name)
with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
- np_vars, np_v = sess.run([true_vars, v])
+ np_vars, np_v = self.evaluate([true_vars, v])
self.assertAllClose(np_v, sum(np_vars))
# TODO(mihaimaruseac): Not converted to use wrap_function because of
@@ -1614,7 +1647,7 @@
with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
- np_vars, np_v = sess.run([true_vars, v])
+ np_vars, np_v = self.evaluate([true_vars, v])
# take products of sums of products
self.assertAllClose(
np_v, (((np_vars[0] * np_vars[1]) + (np_vars[2] * np_vars[3])) + (
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 2bb7510..14ec46d 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function
+import functools
import operator
import numpy as np
@@ -227,13 +228,13 @@
self.assertEqual([2], self.evaluate(v2))
# v0 should still be uninitialized.
with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
- sess.run(v0)
+ self.evaluate(v0)
# We should not be able to run 'add' yet.
with self.assertRaisesRegexp(errors_impl.OpError, "uninitialized"):
- sess.run(add)
+ self.evaluate(add)
# If we initialize v0 we should be able to run 'add'.
self.evaluate(v0.initializer)
- sess.run(add)
+ self.evaluate(add)
def testControlFlowInitialization(self):
"""Expects an error if an initializer is in a control-flow scope."""
@@ -309,6 +310,12 @@
self.assertEqual([var_x], variables.trainable_variables("scope_1"))
self.assertEqual([var_y], variables.trainable_variables("scope_2"))
+ def testOperatorWrapping(self):
+ for attr in functools.WRAPPER_ASSIGNMENTS:
+ self.assertEqual(
+ getattr(variables.Variable.__add__, attr),
+ getattr(ops.Tensor.__add__, attr))
+
def testOperators(self):
with self.cached_session():
var_f = variables.Variable([2.0])
@@ -469,11 +476,11 @@
with ops.Graph().as_default(), self.cached_session() as sess:
# v describes a VariableDef-based variable without an initial value.
v = variables.Variable(variable_def=v_def)
- self.assertEqual(3.0, sess.run(v.initialized_value()))
+ self.assertEqual(3.0, self.evaluate(v.initialized_value()))
# initialized_value should not rerun the initializer_op if the variable
# has already been initialized elsewhere.
- sess.run(v.assign(1.0))
+ self.evaluate(v.assign(1.0))
self.assertEqual(1.0, v.initialized_value().eval())
v_def.ClearField("initial_value_name")
@@ -485,7 +492,7 @@
self.assertProtoEquals(v_def, v.to_proto())
# But attempts to use initialized_value will result in errors.
with self.assertRaises(ValueError):
- sess.run(v.initialized_value())
+ self.evaluate(v.initialized_value())
def testTrainableInProto(self):
with ops.Graph().as_default():
@@ -572,7 +579,7 @@
variables.global_variables_initializer().run()
do_opt = gradient_descent.GradientDescentOptimizer(0.1).minimize(
objective)
- sess.run([do_opt])
+ self.evaluate([do_opt])
self.assertAllClose([[0.9, 0.9], [0.9, 0.9]], self.evaluate(b))
@@ -589,9 +596,9 @@
_ = v, w
inited = variables.assert_variables_initialized()
with self.assertRaisesOpError("Attempting to use uninitialized value"):
- sess.run(inited)
+ self.evaluate(inited)
variables.global_variables_initializer().run()
- sess.run(inited)
+ self.evaluate(inited)
def testVariableList(self):
with ops.Graph().as_default(), self.cached_session() as sess:
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index bd31421..77669f0 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -56,7 +56,7 @@
with self.cached_session(use_gpu=use_gpu) as sess:
loss, backprop = gen_nn_ops.softmax_cross_entropy_with_logits(
np_features, np_labels)
- tf_loss, tf_backprop = sess.run([loss, backprop])
+ tf_loss, tf_backprop = self.evaluate([loss, backprop])
self.assertAllCloseAccordingToType(np_loss, tf_loss)
self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
@@ -80,7 +80,7 @@
loss, backprop = gen_nn_ops.softmax_cross_entropy_with_logits(
np.array([[1.], [-1.], [0.]]).astype(dtype),
np.array([[-1.], [0.], [1.]]).astype(dtype))
- tf_loss, tf_backprop = sess.run([loss, backprop])
+ tf_loss, tf_backprop = self.evaluate([loss, backprop])
self.assertAllClose([0.0, 0.0, 0.0], tf_loss)
self.assertAllClose([[2.0], [1.0], [0.0]], tf_backprop)
@@ -148,7 +148,7 @@
with self.cached_session(use_gpu=use_gpu) as sess:
loss, backprop = gen_nn_ops.softmax_cross_entropy_with_logits(
tf_f, tf_l)
- tf_loss, tf_backprop = sess.run([loss, backprop])
+ tf_loss, tf_backprop = self.evaluate([loss, backprop])
self.assertAllCloseAccordingToType(np_loss, tf_loss)
self.assertAllCloseAccordingToType(np_backprop, tf_backprop)
diff --git a/tensorflow/python/kernel_tests/zero_division_test.py b/tensorflow/python/kernel_tests/zero_division_test.py
index 73ab382..7c82f93 100644
--- a/tensorflow/python/kernel_tests/zero_division_test.py
+++ b/tensorflow/python/kernel_tests/zero_division_test.py
@@ -21,13 +21,14 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
class ZeroDivisionTest(test.TestCase):
def testZeros(self):
- with self.session(use_gpu=True):
+ with test_util.use_gpu():
for dtype in dtypes.uint8, dtypes.int16, dtypes.int32, dtypes.int64:
zero = constant_op.constant(0, dtype=dtype)
one = constant_op.constant(1, dtype=dtype)
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index fccea48..bfe591f 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -23,6 +23,7 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.util import function_utils
@@ -30,10 +31,10 @@
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-
+# Avoid breaking users who directly import this symbol from this file.
+# TODO(fchollet): remove this.
InputSpec = base_layer.InputSpec # pylint: disable=invalid-name
-
_KERAS_STYLE_SCOPE = False
@@ -242,11 +243,11 @@
def _make_unique_name(self, name_uid_map=None, avoid_names=None,
namespace='', zero_based=False):
base_name = base_layer.to_snake_case(self.__class__.__name__)
- name = base_layer.unique_layer_name(base_name,
- name_uid_map=name_uid_map,
- avoid_names=avoid_names,
- namespace=namespace,
- zero_based=zero_based)
+ name = base_layer_utils.unique_layer_name(base_name,
+ name_uid_map=name_uid_map,
+ avoid_names=avoid_names,
+ namespace=namespace,
+ zero_based=zero_based)
return (name, base_name)
@property
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 90abf35..4509967 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -26,6 +26,7 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import base_layer as keras_base_layer
+from tensorflow.python.keras.engine import input_spec
from tensorflow.python.layers import base as base_layers
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_ops
@@ -251,7 +252,7 @@
def __init__(self):
super(CustomerLayer, self).__init__()
- self.input_spec = base_layers.InputSpec(ndim=2)
+ self.input_spec = input_spec.InputSpec(ndim=2)
def call(self, inputs):
return inputs
@@ -278,7 +279,7 @@
def __init__(self):
super(CustomerLayer, self).__init__()
- self.input_spec = base_layers.InputSpec(min_ndim=2)
+ self.input_spec = input_spec.InputSpec(min_ndim=2)
def call(self, inputs):
return inputs
@@ -306,7 +307,7 @@
def __init__(self):
super(CustomerLayer, self).__init__()
- self.input_spec = base_layers.InputSpec(max_ndim=2)
+ self.input_spec = input_spec.InputSpec(max_ndim=2)
def call(self, inputs):
return inputs
@@ -334,7 +335,7 @@
def __init__(self):
super(CustomerLayer, self).__init__()
- self.input_spec = base_layers.InputSpec(dtype='float32')
+ self.input_spec = input_spec.InputSpec(dtype='float32')
def call(self, inputs):
return inputs
@@ -354,7 +355,7 @@
def __init__(self):
super(CustomerLayer, self).__init__()
- self.input_spec = base_layers.InputSpec(axes={-1: 2})
+ self.input_spec = input_spec.InputSpec(axes={-1: 2})
def call(self, inputs):
return inputs
@@ -376,7 +377,7 @@
def __init__(self):
super(CustomerLayer, self).__init__()
- self.input_spec = base_layers.InputSpec(shape=(None, 3))
+ self.input_spec = input_spec.InputSpec(shape=(None, 3))
def call(self, inputs):
return inputs
diff --git a/tensorflow/python/layers/layers.py b/tensorflow/python/layers/layers.py
index 11a2ebc..93eec38 100644
--- a/tensorflow/python/layers/layers.py
+++ b/tensorflow/python/layers/layers.py
@@ -24,7 +24,7 @@
# Base objects.
from tensorflow.python.layers.base import Layer
-from tensorflow.python.layers.base import InputSpec
+from tensorflow.python.keras.engine.input_spec import InputSpec
# Core layers.
from tensorflow.python.layers.core import Dense
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index febc358..cc3badb 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -323,7 +323,7 @@
# Test training with placeholder learning phase.
self.evaluate(variables.global_variables_initializer())
- np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 4, 1))
np_beta = np.reshape(np_beta, (1, 4, 1))
@@ -336,7 +336,8 @@
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
- moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ moving_mean, moving_var = self.evaluate(
+ [bn.moving_mean, bn.moving_variance])
np_inputs = self.evaluate(inputs)
mean = np.mean(np_inputs, axis=(0, 2))
std = np.std(np_inputs, axis=(0, 2))
@@ -364,7 +365,7 @@
with self.cached_session() as sess:
# Test training with placeholder learning phase.
self.evaluate(variables.global_variables_initializer())
- np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 3))
np_beta = np.reshape(np_beta, (1, 1, 3))
for _ in range(100):
@@ -376,7 +377,8 @@
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
- moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ moving_mean, moving_var = self.evaluate(
+ [bn.moving_mean, bn.moving_variance])
np_inputs = self.evaluate(inputs)
mean = np.mean(np_inputs, axis=(0, 1))
std = np.std(np_inputs, axis=(0, 1))
@@ -405,7 +407,7 @@
with self.session(use_gpu=True) as sess:
# Test training with placeholder learning phase.
self.evaluate(variables.global_variables_initializer())
- np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
np_beta = np.reshape(np_beta, (1, 4, 1, 1))
for _ in range(100):
@@ -417,7 +419,8 @@
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
- moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ moving_mean, moving_var = self.evaluate(
+ [bn.moving_mean, bn.moving_variance])
np_inputs = self.evaluate(inputs)
mean = np.mean(np_inputs, axis=(0, 2, 3))
std = np.std(np_inputs, axis=(0, 2, 3))
@@ -445,7 +448,7 @@
with self.cached_session() as sess:
# Test training with placeholder learning phase.
self.evaluate(variables.global_variables_initializer())
- np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 3, 1))
np_beta = np.reshape(np_beta, (1, 1, 3, 1))
for _ in range(100):
@@ -457,7 +460,8 @@
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
- moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ moving_mean, moving_var = self.evaluate(
+ [bn.moving_mean, bn.moving_variance])
np_inputs = self.evaluate(inputs)
mean = np.mean(np_inputs, axis=(0, 1, 3))
std = np.std(np_inputs, axis=(0, 1, 3))
@@ -485,7 +489,7 @@
with self.cached_session() as sess:
# Test training with placeholder learning phase.
self.evaluate(variables.global_variables_initializer())
- np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
for _ in range(100):
@@ -497,7 +501,8 @@
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
- moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ moving_mean, moving_var = self.evaluate(
+ [bn.moving_mean, bn.moving_variance])
np_inputs = self.evaluate(inputs)
mean = np.mean(np_inputs, axis=(0, 1, 2))
std = np.std(np_inputs, axis=(0, 1, 2))
@@ -525,7 +530,7 @@
with self.cached_session() as sess:
# Test training with placeholder learning phase.
self.evaluate(variables.global_variables_initializer())
- np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
for _ in range(100):
@@ -537,7 +542,8 @@
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
- moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ moving_mean, moving_var = self.evaluate(
+ [bn.moving_mean, bn.moving_variance])
np_inputs = self.evaluate(inputs)
mean = np.mean(np_inputs, axis=(0, 1, 2))
std = np.std(np_inputs, axis=(0, 1, 2))
@@ -566,7 +572,7 @@
with self.cached_session() as sess:
# Test training with placeholder learning phase.
self.evaluate(variables.global_variables_initializer())
- np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 4, 1, 1))
np_beta = np.reshape(np_beta, (1, 4, 1, 1))
for _ in range(100):
@@ -578,7 +584,8 @@
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
- moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ moving_mean, moving_var = self.evaluate(
+ [bn.moving_mean, bn.moving_variance])
np_inputs = self.evaluate(inputs)
mean = np.mean(np_inputs, axis=(0, 2, 3))
std = np.std(np_inputs, axis=(0, 2, 3))
@@ -606,7 +613,7 @@
with self.cached_session() as sess:
# Test training with placeholder learning phase.
self.evaluate(variables.global_variables_initializer())
- np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
for _ in range(100):
@@ -619,7 +626,8 @@
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
- moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ moving_mean, moving_var = self.evaluate(
+ [bn.moving_mean, bn.moving_variance])
np_inputs = self.evaluate(inputs)
mean = np.mean(np_inputs, axis=(0, 1, 2))
std = np.std(np_inputs, axis=(0, 1, 2))
@@ -647,7 +655,7 @@
with self.cached_session() as sess:
# Test training with placeholder learning phase.
self.evaluate(variables.global_variables_initializer())
- np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
for _ in range(100):
@@ -658,7 +666,8 @@
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
- moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ moving_mean, moving_var = self.evaluate(
+ [bn.moving_mean, bn.moving_variance])
np_inputs = self.evaluate(inputs)
mean = np.mean(np_inputs, axis=(0, 1, 2))
std = np.std(np_inputs, axis=(0, 1, 2))
@@ -697,7 +706,7 @@
with self.cached_session() as sess:
# Test training with placeholder learning phase.
self.evaluate(variables.global_variables_initializer())
- np_gamma, np_beta = sess.run([gamma, beta])
+ np_gamma, np_beta = self.evaluate([gamma, beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
for _ in range(100):
@@ -709,7 +718,8 @@
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
- np_moving_mean, np_moving_var = sess.run([moving_mean, moving_variance])
+ np_moving_mean, np_moving_var = self.evaluate(
+ [moving_mean, moving_variance])
np_inputs = self.evaluate(inputs)
np_mean = np.mean(np_inputs, axis=(0, 1, 2))
np_std = np.std(np_inputs, axis=(0, 1, 2))
@@ -764,7 +774,8 @@
feed_dict={training: True})
# Verify that the statistics are updated during training.
- np_moving_mean, np_moving_var = sess.run([moving_mean, moving_variance])
+ np_moving_mean, np_moving_var = self.evaluate(
+ [moving_mean, moving_variance])
np_inputs = self.evaluate(inputs2)
np_mean = np.mean(np_inputs, axis=(0, 1, 2))
np_std = np.std(np_inputs, axis=(0, 1, 2))
@@ -773,7 +784,7 @@
self.assertAllClose(np_variance, np_moving_var, atol=1e-2)
# Verify that the axis is normalized during training.
- np_gamma, np_beta = sess.run([gamma, beta])
+ np_gamma, np_beta = self.evaluate([gamma, beta])
np_gamma = np.reshape(np_gamma, (1, 1, 1, 6))
np_beta = np.reshape(np_beta, (1, 1, 1, 6))
normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta
@@ -1258,7 +1269,7 @@
# Test training with placeholder learning phase.
self.evaluate(variables.global_variables_initializer())
- np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
for _ in range(100):
np_output, _, _ = sess.run([outputs] + bn.updates,
@@ -1269,7 +1280,8 @@
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
- moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ moving_mean, moving_var = self.evaluate(
+ [bn.moving_mean, bn.moving_variance])
np_inputs = self.evaluate(inputs)
mean = np.mean(np_inputs, axis=0, keepdims=True)
std = np.std(np_inputs, axis=0, keepdims=True)
@@ -1298,7 +1310,7 @@
# Test training with placeholder learning phase.
self.evaluate(variables.global_variables_initializer())
- np_gamma, np_beta = sess.run([bn.gamma, bn.beta])
+ np_gamma, np_beta = self.evaluate([bn.gamma, bn.beta])
for _ in range(100):
np_output, _, _ = sess.run([outputs] + bn.updates,
@@ -1309,7 +1321,8 @@
self.assertAlmostEqual(np.std(normed_np_output), 1., places=1)
# Verify that the statistics are updated during training.
- moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance])
+ moving_mean, moving_var = self.evaluate(
+ [bn.moving_mean, bn.moving_variance])
np_inputs = self.evaluate(inputs)
mean = np.mean(np_inputs, axis=(0, 4), keepdims=True)
std = np.std(np_inputs, axis=(0, 4), keepdims=True)
diff --git a/tensorflow/python/lib/io/file_io.py b/tensorflow/python/lib/io/file_io.py
index c8aa531..645cf8e 100644
--- a/tensorflow/python/lib/io/file_io.py
+++ b/tensorflow/python/lib/io/file_io.py
@@ -258,7 +258,7 @@
return file_exists_v2(filename)
-@tf_export("io.gfile.exists", v1=[])
+@tf_export("io.gfile.exists")
def file_exists_v2(path):
"""Determines whether a path exists or not.
@@ -280,7 +280,7 @@
return True
-@tf_export("gfile.Remove")
+@tf_export(v1=["gfile.Remove"])
def delete_file(filename):
"""Deletes the file located at 'filename'.
@@ -291,8 +291,22 @@
errors.OpError: Propagates any errors reported by the FileSystem API. E.g.,
NotFoundError if the file does not exist.
"""
+ delete_file_v2(filename)
+
+
+@tf_export("io.gfile.remove")
+def delete_file_v2(path):
+ """Deletes the path located at 'path'.
+
+ Args:
+ path: string, a path
+
+ Raises:
+ errors.OpError: Propagates any errors reported by the FileSystem API. E.g.,
+ NotFoundError if the path does not exist.
+ """
with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.DeleteFile(compat.as_bytes(filename), status)
+ pywrap_tensorflow.DeleteFile(compat.as_bytes(path), status)
def read_file_to_string(filename, binary_mode=False):
@@ -331,7 +345,7 @@
f.write(file_content)
-@tf_export("gfile.Glob")
+@tf_export(v1=["gfile.Glob"])
def get_matching_files(filename):
"""Returns a list of files that match the given pattern(s).
@@ -344,25 +358,41 @@
Raises:
errors.OpError: If there are filesystem / directory listing errors.
"""
+ return get_matching_files_v2(filename)
+
+
+@tf_export("io.gfile.glob")
+def get_matching_files_v2(pattern):
+ """Returns a list of files that match the given pattern(s).
+
+ Args:
+ pattern: string or iterable of strings. The glob pattern(s).
+
+ Returns:
+ A list of strings containing filenames that match the given pattern(s).
+
+ Raises:
+ errors.OpError: If there are filesystem / directory listing errors.
+ """
with errors.raise_exception_on_not_ok_status() as status:
- if isinstance(filename, six.string_types):
+ if isinstance(pattern, six.string_types):
return [
# Convert the filenames to string from bytes.
compat.as_str_any(matching_filename)
for matching_filename in pywrap_tensorflow.GetMatchingFiles(
- compat.as_bytes(filename), status)
+ compat.as_bytes(pattern), status)
]
else:
return [
# Convert the filenames to string from bytes.
compat.as_str_any(matching_filename)
- for single_filename in filename
+ for single_filename in pattern
for matching_filename in pywrap_tensorflow.GetMatchingFiles(
compat.as_bytes(single_filename), status)
]
-@tf_export("gfile.MkDir")
+@tf_export(v1=["gfile.MkDir"])
def create_dir(dirname):
"""Creates a directory with the name 'dirname'.
@@ -376,11 +406,28 @@
Raises:
errors.OpError: If the operation fails.
"""
+ create_dir_v2(dirname)
+
+
+@tf_export("io.gfile.mkdir")
+def create_dir_v2(path):
+ """Creates a directory with the name given by 'path'.
+
+ Args:
+ path: string, name of the directory to be created
+
+ Notes:
+ The parent directories need to exist. Use recursive_create_dir instead if
+ there is the possibility that the parent dirs don't exist.
+
+ Raises:
+ errors.OpError: If the operation fails.
+ """
with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.CreateDir(compat.as_bytes(dirname), status)
+ pywrap_tensorflow.CreateDir(compat.as_bytes(path), status)
-@tf_export("gfile.MakeDirs")
+@tf_export(v1=["gfile.MakeDirs"])
def recursive_create_dir(dirname):
"""Creates a directory and all parent/intermediate directories.
@@ -392,11 +439,26 @@
Raises:
errors.OpError: If the operation fails.
"""
+ recursive_create_dir_v2(dirname)
+
+
+@tf_export("io.gfile.makedirs")
+def recursive_create_dir_v2(path):
+ """Creates a directory and all parent/intermediate directories.
+
+ It succeeds if path already exists and is writable.
+
+ Args:
+ path: string, name of the directory to be created
+
+ Raises:
+ errors.OpError: If the operation fails.
+ """
with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.RecursivelyCreateDir(compat.as_bytes(dirname), status)
+ pywrap_tensorflow.RecursivelyCreateDir(compat.as_bytes(path), status)
-@tf_export("gfile.Copy")
+@tf_export(v1=["gfile.Copy"])
def copy(oldpath, newpath, overwrite=False):
"""Copies data from oldpath to newpath.
@@ -409,12 +471,28 @@
Raises:
errors.OpError: If the operation fails.
"""
+ copy_v2(oldpath, newpath, overwrite)
+
+
+@tf_export("io.gfile.copy")
+def copy_v2(src, dst, overwrite=False):
+ """Copies data from src to dst.
+
+ Args:
+ src: string, name of the file whose contents need to be copied
+ dst: string, name of the file to which to copy to
+ overwrite: boolean, if false its an error for newpath to be occupied by an
+ existing file.
+
+ Raises:
+ errors.OpError: If the operation fails.
+ """
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.CopyFile(
- compat.as_bytes(oldpath), compat.as_bytes(newpath), overwrite, status)
+ compat.as_bytes(src), compat.as_bytes(dst), overwrite, status)
-@tf_export("gfile.Rename")
+@tf_export(v1=["gfile.Rename"])
def rename(oldname, newname, overwrite=False):
"""Rename or move a file / directory.
@@ -427,9 +505,25 @@
Raises:
errors.OpError: If the operation fails.
"""
+ rename_v2(oldname, newname, overwrite)
+
+
+@tf_export("io.gfile.rename")
+def rename_v2(src, dst, overwrite):
+ """Rename or move a file / directory.
+
+ Args:
+ src: string, pathname for a file
+ dst: string, pathname to which the file needs to be moved
+ overwrite: boolean, if false it's an error for `dst` to be occupied by
+ an existing file.
+
+ Raises:
+ errors.OpError: If the operation fails.
+ """
with errors.raise_exception_on_not_ok_status() as status:
pywrap_tensorflow.RenameFile(
- compat.as_bytes(oldname), compat.as_bytes(newname), overwrite, status)
+ compat.as_bytes(src), compat.as_bytes(dst), overwrite, status)
def atomic_write_string_to_file(filename, contents, overwrite=True):
@@ -456,7 +550,7 @@
raise
-@tf_export("gfile.DeleteRecursively")
+@tf_export(v1=["gfile.DeleteRecursively"])
def delete_recursively(dirname):
"""Deletes everything under dirname recursively.
@@ -466,11 +560,24 @@
Raises:
errors.OpError: If the operation fails.
"""
+ delete_recursively_v2(dirname)
+
+
+@tf_export("io.gfile.rmtree")
+def delete_recursively_v2(path):
+ """Deletes everything under path recursively.
+
+ Args:
+ path: string, a path
+
+ Raises:
+ errors.OpError: If the operation fails.
+ """
with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.DeleteRecursively(compat.as_bytes(dirname), status)
+ pywrap_tensorflow.DeleteRecursively(compat.as_bytes(path), status)
-@tf_export("gfile.IsDirectory")
+@tf_export(v1=["gfile.IsDirectory"])
def is_directory(dirname):
"""Returns whether the path is a directory or not.
@@ -480,11 +587,24 @@
Returns:
True, if the path is a directory; False otherwise
"""
+ return is_directory_v2(dirname)
+
+
+@tf_export("io.gfile.isdir")
+def is_directory_v2(path):
+ """Returns whether the path is a directory or not.
+
+ Args:
+ path: string, path to a potential directory
+
+ Returns:
+ True, if the path is a directory; False otherwise
+ """
status = c_api_util.ScopedTFStatus()
- return pywrap_tensorflow.IsDirectory(compat.as_bytes(dirname), status)
+ return pywrap_tensorflow.IsDirectory(compat.as_bytes(path), status)
-@tf_export("gfile.ListDirectory")
+@tf_export(v1=["gfile.ListDirectory"])
def list_directory(dirname):
"""Returns a list of entries contained within a directory.
@@ -500,7 +620,26 @@
Raises:
errors.NotFoundError if directory doesn't exist
"""
- if not is_directory(dirname):
+ return list_directory_v2(dirname)
+
+
+@tf_export("io.gfile.listdir")
+def list_directory_v2(path):
+ """Returns a list of entries contained within a directory.
+
+ The list is in arbitrary order. It does not contain the special entries "."
+ and "..".
+
+ Args:
+ path: string, path to a directory
+
+ Returns:
+ [filename1, filename2, ... filenameN] as strings
+
+ Raises:
+ errors.NotFoundError if directory doesn't exist
+ """
+ if not is_directory(path):
raise errors.NotFoundError(None, None, "Could not find directory")
with errors.raise_exception_on_not_ok_status() as status:
# Convert each element to string, since the return values of the
@@ -508,11 +647,11 @@
return [
compat.as_str_any(filename)
for filename in pywrap_tensorflow.GetChildren(
- compat.as_bytes(dirname), status)
+ compat.as_bytes(path), status)
]
-@tf_export("gfile.Walk")
+@tf_export(v1=["gfile.Walk"])
def walk(top, in_order=True):
"""Recursive directory tree generator for directories.
@@ -528,11 +667,35 @@
(dirname, [subdirname, subdirname, ...], [filename, filename, ...])
as strings
"""
+ return walk_v2(top, in_order)
+
+
+@tf_export("io.gfile.walk")
+def walk_v2(top, topdown, onerror=None):
+ """Recursive directory tree generator for directories.
+
+ Args:
+ top: string, a Directory name
+ topdown: bool, Traverse pre order if True, post order if False.
+ onerror: optional handler for errors. Should be a function, it will be
+ called with the error as argument. Rethrowing the error aborts the walk.
+
+ Errors that happen while listing directories are ignored.
+
+ Yields:
+ Each yield is a 3-tuple: the pathname of a directory, followed by lists of
+ all its subdirectories and leaf files.
+ (dirname, [subdirname, subdirname, ...], [filename, filename, ...])
+ as strings
+ """
top = compat.as_str_any(top)
try:
listing = list_directory(top)
- except errors.NotFoundError:
- return
+ except errors.NotFoundError as err:
+ if onerror:
+ onerror(err)
+ else:
+ return
files = []
subdirs = []
@@ -545,18 +708,18 @@
here = (top, subdirs, files)
- if in_order:
+ if topdown:
yield here
for subdir in subdirs:
- for subitem in walk(os.path.join(top, subdir), in_order):
+ for subitem in walk_v2(os.path.join(top, subdir), topdown, onerror=onerror):
yield subitem
- if not in_order:
+ if not topdown:
yield here
-@tf_export("gfile.Stat")
+@tf_export(v1=["gfile.Stat"])
def stat(filename):
"""Returns file statistics for a given path.
@@ -569,9 +732,25 @@
Raises:
errors.OpError: If the operation fails.
"""
+ return stat_v2(filename)
+
+
+@tf_export("io.gfile.stat")
+def stat_v2(path):
+ """Returns file statistics for a given path.
+
+ Args:
+ path: string, path to a file
+
+ Returns:
+ FileStatistics struct that contains information about the path
+
+ Raises:
+ errors.OpError: If the operation fails.
+ """
file_statistics = pywrap_tensorflow.FileStatistics()
with errors.raise_exception_on_not_ok_status() as status:
- pywrap_tensorflow.Stat(compat.as_bytes(filename), file_statistics, status)
+ pywrap_tensorflow.Stat(compat.as_bytes(path), file_statistics, status)
return file_statistics
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index b7fae85..43086ab 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -150,10 +150,11 @@
return options
-@tf_export(
- "io.tf_record_iterator",
- v1=["io.tf_record_iterator", "python_io.tf_record_iterator"])
-@deprecation.deprecated_endpoints("python_io.tf_record_iterator")
+@tf_export(v1=["io.tf_record_iterator", "python_io.tf_record_iterator"])
+@deprecation.deprecated(
+ date=None,
+ instructions=("Use eager execution and: \n"
+ "`tf.data.TFRecordDataset(path)`"))
def tf_record_iterator(path, options=None):
"""An iterator that read the records from a TFRecords file.
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 68c392b..6edc193 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -489,10 +489,12 @@
@ops.RegisterGradient("CheckNumerics")
-def _CheckNumericsGrad(_, grad):
+def _CheckNumericsGrad(op, grad):
"""Gradient for check_numerics op."""
return array_ops.check_numerics(
- grad, "Not a number (NaN) or infinity (Inf) values detected in gradient.")
+ grad,
+ "Not a number (NaN) or infinity (Inf) values detected in gradient. %s" %
+ op.get_attr("message"))
@ops.RegisterGradient("PlaceholderWithDefault")
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 2a7989e..496d385 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1507,7 +1507,75 @@
value=value, size_splits=size_splits, axis=axis, num_split=num, name=name)
-@tf_export("transpose")
+@tf_export("transpose", v1=[])
+def transpose_v2(a, perm=None, conjugate=False, name="transpose"):
+ """Transposes `a`. Permutes the dimensions according to `perm`.
+
+ The returned tensor's dimension i will correspond to the input dimension
+ `perm[i]`. If `perm` is not given, it is set to (n-1...0), where n is
+ the rank of the input tensor. Hence by default, this operation performs a
+ regular matrix transpose on 2-D input Tensors. If conjugate is True and
+ `a.dtype` is either `complex64` or `complex128` then the values of `a`
+ are conjugated and transposed.
+
+ @compatibility(numpy)
+ In `numpy` transposes are memory-efficient constant time operations as they
+ simply return a new view of the same data with adjusted `strides`.
+
+ TensorFlow does not support strides, so `transpose` returns a new tensor with
+ the items permuted.
+ @end_compatibility
+
+ For example:
+
+ ```python
+ x = tf.constant([[1, 2, 3], [4, 5, 6]])
+ tf.transpose(x) # [[1, 4]
+ # [2, 5]
+ # [3, 6]]
+
+ # Equivalently
+ tf.transpose(x, perm=[1, 0]) # [[1, 4]
+ # [2, 5]
+ # [3, 6]]
+
+ # If x is complex, setting conjugate=True gives the conjugate transpose
+ x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
+ [4 + 4j, 5 + 5j, 6 + 6j]])
+ tf.transpose(x, conjugate=True) # [[1 - 1j, 4 - 4j],
+ # [2 - 2j, 5 - 5j],
+ # [3 - 3j, 6 - 6j]]
+
+ # 'perm' is more useful for n-dimensional tensors, for n > 2
+ x = tf.constant([[[ 1, 2, 3],
+ [ 4, 5, 6]],
+ [[ 7, 8, 9],
+ [10, 11, 12]]])
+
+ # Take the transpose of the matrices in dimension-0
+ # (this common operation has a shorthand `linalg.transpose`)
+ tf.transpose(x, perm=[0, 2, 1]) # [[[1, 4],
+ # [2, 5],
+ # [3, 6]],
+ # [[7, 10],
+ # [8, 11],
+ # [9, 12]]]
+ ```
+
+ Args:
+ a: A `Tensor`.
+ perm: A permutation of the dimensions of `a`.
+ conjugate: Optional bool. Setting it to `True` is mathematically equivalent
+ to tf.conj(tf.transpose(input)).
+ name: A name for the operation (optional).
+
+ Returns:
+ A transposed `Tensor`.
+ """
+ return transpose(a=a, perm=perm, name=name, conjugate=conjugate)
+
+
+@tf_export(v1=["transpose"])
def transpose(a, perm=None, name="transpose", conjugate=False):
"""Transposes `a`. Permutes the dimensions according to `perm`.
@@ -2607,7 +2675,7 @@
depth_to_space.__doc__ = gen_array_ops.depth_to_space.__doc__
-@tf_export("batch_to_space")
+@tf_export(v1=["batch_to_space"])
def batch_to_space(input, crops, block_size, name=None): # pylint: disable=redefined-builtin
result = batch_to_space_nd(
input,
@@ -2621,6 +2689,151 @@
batch_to_space.__doc__ = gen_array_ops.batch_to_space.__doc__
+@tf_export("batch_to_space", v1=[])
+def batch_to_space_v2(input, block_shape, crops, name=None): # pylint: disable=redefined-builtin
+ """BatchToSpace for N-D tensors of type T.
+
+ This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of
+ shape `block_shape + [batch]`, interleaves these blocks back into the grid
+ defined by the spatial dimensions `[1, ..., M]`, to obtain a result with the
+ same rank as the input. The spatial dimensions of this intermediate result
+ are then optionally cropped according to `crops` to produce the output. This
+ is the reverse of SpaceToBatch. See below for a precise description.
+
+ Args:
+ input: A `Tensor`.
+ N-D with shape `input_shape = [batch] + spatial_shape + remaining_shape`,
+ where spatial_shape has M dimensions.
+ block_shape: A `Tensor`. Must be one of the following types:
+ `int32`, `int64`. 1-D with shape `[M]`, all values must be >= 1.
+ For backwards compatibility with TF 1.0, this parameter may be an int, in
+ which case it is converted to
+ `numpy.array([block_shape, block_shape], dtype=numpy.int64)`.
+ crops: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+ 2-D with shape `[M, 2]`, all values must be >= 0.
+ `crops[i] = [crop_start, crop_end]` specifies the amount to crop from
+ input dimension `i + 1`, which corresponds to spatial dimension `i`. It
+ is required that
+ `crop_start[i] + crop_end[i] <= block_shape[i] * input_shape[i + 1]`.
+
+ This operation is equivalent to the following steps:
+
+ 1. Reshape `input` to `reshaped` of shape:
+ [block_shape[0], ..., block_shape[M-1],
+ batch / prod(block_shape),
+ input_shape[1], ..., input_shape[N-1]]
+
+ 2. Permute dimensions of `reshaped` to produce `permuted` of shape
+ [batch / prod(block_shape),
+
+ input_shape[1], block_shape[0],
+ ...,
+ input_shape[M], block_shape[M-1],
+
+ input_shape[M+1], ..., input_shape[N-1]]
+
+ 3. Reshape `permuted` to produce `reshaped_permuted` of shape
+ [batch / prod(block_shape),
+
+ input_shape[1] * block_shape[0],
+ ...,
+ input_shape[M] * block_shape[M-1],
+
+ input_shape[M+1],
+ ...,
+ input_shape[N-1]]
+
+ 4. Crop the start and end of dimensions `[1, ..., M]` of
+ `reshaped_permuted` according to `crops` to produce the
+ output of shape:
+ [batch / prod(block_shape),
+
+ input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
+ ...,
+ input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
+
+ input_shape[M+1], ..., input_shape[N-1]]
+
+ Some examples:
+
+ (1) For the following input of shape `[4, 1, 1, 1]`,
+ `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
+
+ ```
+ [[[[1]]], [[[2]]], [[[3]]], [[[4]]]]
+ ```
+
+ The output tensor has shape `[1, 2, 2, 1]` and value:
+
+ ```
+ x = [[[[1], [2]], [[3], [4]]]]
+ ```
+
+ (2) For the following input of shape `[4, 1, 1, 3]`,
+ `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
+
+ ```
+ [[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]], [[10, 11, 12]]]
+ ```
+
+ The output tensor has shape `[1, 2, 2, 3]` and value:
+
+ ```
+ x = [[[[1, 2, 3], [4, 5, 6]],
+ [[7, 8, 9], [10, 11, 12]]]]
+ ```
+
+ (3) For the following input of shape `[4, 2, 2, 1]`,
+ `block_shape = [2, 2]`, and `crops = [[0, 0], [0, 0]]`:
+
+ ```
+ x = [[[[1], [3]], [[9], [11]]],
+ [[[2], [4]], [[10], [12]]],
+ [[[5], [7]], [[13], [15]]],
+ [[[6], [8]], [[14], [16]]]]
+ ```
+
+ The output tensor has shape `[1, 4, 4, 1]` and value:
+
+ ```
+ x = [[[1], [2], [3], [4]],
+ [[5], [6], [7], [8]],
+ [[9], [10], [11], [12]],
+ [[13], [14], [15], [16]]]
+ ```
+
+ (4) For the following input of shape `[8, 1, 3, 1]`,
+ `block_shape = [2, 2]`, and `crops = [[0, 0], [2, 0]]`:
+
+ ```
+ x = [[[[0], [1], [3]]], [[[0], [9], [11]]],
+ [[[0], [2], [4]]], [[[0], [10], [12]]],
+ [[[0], [5], [7]]], [[[0], [13], [15]]],
+ [[[0], [6], [8]]], [[[0], [14], [16]]]]
+ ```
+
+ The output tensor has shape `[2, 2, 4, 1]` and value:
+
+ ```
+ x = [[[[1], [2], [3], [4]],
+ [[5], [6], [7], [8]]],
+ [[[9], [10], [11], [12]],
+ [[13], [14], [15], [16]]]]
+ ```
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has the same type as `input`.
+ """
+ if isinstance(block_shape, int):
+ block_shape = np.array([block_shape, block_shape], dtype=np.int64)
+
+ return batch_to_space_nd(input=input,
+ block_shape=block_shape,
+ crops=crops,
+ name=name)
+
+
@tf_export("one_hot")
def one_hot(indices,
depth,
@@ -2844,7 +3057,7 @@
return gen_math_ops.cast(result, dtype)
-@tf_export("squeeze")
+@tf_export(v1=["squeeze"])
@deprecation.deprecated_args(None, "Use the `axis` argument instead",
"squeeze_dims")
def squeeze(input, axis=None, name=None, squeeze_dims=None):
@@ -2894,6 +3107,12 @@
return gen_array_ops.squeeze(input, axis, name)
+@tf_export("squeeze", v1=[])
+def squeeze_v2(input, axis=None, name=None):
+ # pylint: disable=redefined-builtin
+ return squeeze(input, axis, name)
+
+
@tf_export("where")
def where(condition, x=None, y=None, name=None):
"""Return the elements, either from `x` or `y`, depending on the `condition`.
@@ -2997,7 +3216,7 @@
# pylint: enable=redefined-builtin
-@tf_export("gather")
+@tf_export(v1=["gather"])
def gather(params, indices, validate_indices=None, name=None, axis=0):
del validate_indices
if axis != 0:
@@ -3013,7 +3232,14 @@
return gen_array_ops.gather_v2(params, indices, axis, name=name)
-gather.__doc__ = gen_array_ops.gather_v2.__doc__
+@tf_export("gather", v1=[])
+def gather_v2(params, indices, validate_indices=None, axis=0, name=None):
+ return gather(params, indices, validate_indices=validate_indices, name=name,
+ axis=axis)
+
+
+gather.__doc__ = gather_v2.__doc__ = gen_array_ops.gather_v2.__doc__
+
@tf_export("batch_gather")
@@ -3201,3 +3427,48 @@
quantize.__doc__ = gen_array_ops.quantize_v2.__doc__
+
+
+@tf_export("image.extract_image_patches", v1=[])
+def extract_image_patches_v2(
+ images,
+ sizes,
+ strides,
+ rates,
+ padding,
+ name=None):
+ # pylint: disable=line-too-long
+ r"""Extract `patches` from `images` and put them in the \"depth\" output dimension.
+
+ Args:
+ images: A 4-D Tensor with shape `[batch, in_rows, in_cols, depth]
+ sizes: The size of the sliding window for each dimension of `images`.
+ strides: A 1-D Tensor of length 4. How far the centers of two consecutive
+ patches are in the images. Must be: `[1, stride_rows, stride_cols, 1]`.
+ rates: A 1-D Tensor of length 4. Must be: `[1, rate_rows, rate_cols, 1]`.
+ This is the input stride, specifying how far two consecutive patch samples
+ are in the input. Equivalent to extracting patches with `patch_sizes_eff =
+ patch_sizes + (patch_sizes - 1) * (rates - 1)`, followed by subsampling
+ them spatially by a factor of `rates`. This is equivalent to `rate` in
+ dilated (a.k.a. Atrous) convolutions.
+ padding: The type of padding algorithm to use.
+ We specify the size-related attributes as: ```python ksizes = [1,
+ ksize_rows, ksize_cols, 1] strides = [1, strides_rows, strides_cols, 1]
+ rates = [1, rates_rows, rates_cols, 1]
+ name: A name for the operation (optional).
+
+ Returns:
+ A 4-D Tensor. Has the same type as `images`, and with shape `[batch,
+ out_rows, out_cols, ksize_rows * ksize_cols * depth]` containing image
+ patches with size `ksize_rows x ksize_cols x depth` vectorized in the
+ \"depth\" dimension. Note `out_rows` and `out_cols` are the dimensions of
+ the output patches.
+ """
+ # pylint: enable=line-too-long
+ return gen_array_ops.extract_image_patches(
+ images, sizes, strides, rates, padding, name)
+
+extract_image_patches_deprecation = deprecation.deprecated_args(
+ None, "ksizes is deprecated, use sizes instead", "ksizes")
+tf_export(v1=["image.extract_image_patches", "extract_image_patches"])(
+ extract_image_patches_deprecation(gen_array_ops.extract_image_patches))
diff --git a/tensorflow/python/ops/bitwise_ops_test.py b/tensorflow/python/ops/bitwise_ops_test.py
index dfb40db..7392782 100644
--- a/tensorflow/python/ops/bitwise_ops_test.py
+++ b/tensorflow/python/ops/bitwise_ops_test.py
@@ -59,14 +59,15 @@
2**31 - 1, 2**31, 2**32 - 1, 2**32, -2**32 + 1, -2**32,
-2**63 + 1, 2**63 - 1]
def count_bits(x):
- return sum([bin(z).count("1") for z in six.iterbytes(x.tobytes())])
+ return sum(bin(z).count("1") for z in six.iterbytes(x.tobytes()))
for dtype in dtype_list:
with self.cached_session(use_gpu=True) as sess:
print("PopulationCount test: ", dtype)
inputs = np.array(raw_inputs, dtype=dtype.as_numpy_dtype)
truth = [count_bits(x) for x in inputs]
input_tensor = constant_op.constant(inputs, dtype=dtype)
- popcnt_result = sess.run(gen_bitwise_ops.population_count(input_tensor))
+ popcnt_result = self.evaluate(
+ gen_bitwise_ops.population_count(input_tensor))
self.assertAllEqual(truth, popcnt_result)
def testInvertOp(self):
@@ -89,7 +90,7 @@
self.assertAllEqual(not_a_or_a, [not_0] * 4)
# For unsigned dtypes let's also check the result directly.
if dtype.is_unsigned:
- inverted = sess.run(bitwise_ops.invert(input_tensor))
+ inverted = self.evaluate(bitwise_ops.invert(input_tensor))
expected = [dtype.max - x for x in inputs]
self.assertAllEqual(inverted, expected)
diff --git a/tensorflow/python/ops/clip_ops_test.py b/tensorflow/python/ops/clip_ops_test.py
index 8aa9c4f..e9f7941 100644
--- a/tensorflow/python/ops/clip_ops_test.py
+++ b/tensorflow/python/ops/clip_ops_test.py
@@ -35,7 +35,7 @@
input_op = constant_op.constant(inputs)
clipped = clip_ops.clip_by_norm(input_op, max_norm)
check_op = numerics.add_check_numerics_ops()
- result, _ = sess.run([clipped, check_op])
+ result, _ = self.evaluate([clipped, check_op])
self.assertAllClose(result, expected)
def _testClipIndexedSlicesByNorm(self, values, indices, shape, max_norm,
@@ -54,7 +54,7 @@
# Tensor mode
dense_tensor = ops.convert_to_tensor(indixed_slices)
dense_clipped = clip_ops.clip_by_norm(dense_tensor, max_norm, axes)
- result, expected = sess.run([clipped, dense_clipped])
+ result, expected = self.evaluate([clipped, dense_clipped])
self.assertAllClose(result, expected)
def testClipTensorByNorm(self):
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index 0f08c61..927c649 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -25,12 +25,14 @@
import collections
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util_v2 as util
from tensorflow.python.ops import gen_functional_ops
+from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.util import nest
@@ -74,55 +76,14 @@
false_name, read_only_collections=False),
add_control_dependencies=add_control_dependencies,
op_return_value=pred)
- _check_same_outputs(true_graph, false_graph)
- # Add inputs to true_graph and false_graph to make them match. Note that
- # this modifies true_graph and false_graph.
- cond_inputs = _make_inputs_match(true_graph, false_graph,
- true_graph.external_captures,
- false_graph.external_captures)
-
- # Add all intermediate tensors as function outputs so they're available for
- # the gradient computation.
-
- true_intermediates = _get_intermediates(true_graph)
- false_intermediates = _get_intermediates(false_graph)
-
- # Save the original number of outputs to return to the caller.
- num_cond_outputs = len(true_graph.outputs)
-
- # Make the number/type of new intermediate outputs match.
- extra_true_outputs, extra_false_outputs = _pad_params(
- true_graph, false_graph, true_intermediates, false_intermediates)
-
- true_graph.outputs.extend(extra_true_outputs)
- false_graph.outputs.extend(extra_false_outputs)
-
- # Create the If op.
- tensors = gen_functional_ops._if( # pylint: disable=protected-access
- pred,
- cond_inputs, [t.dtype for t in true_graph.outputs],
- util.create_new_tf_function(true_graph),
- util.create_new_tf_function(false_graph),
- output_shapes=_get_output_shapes(true_graph.outputs,
- false_graph.outputs),
- name=scope)
-
- # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
- util.maybe_set_lowering_attr(tensors[0].op)
-
- # Return identities for each output of the If op, rather than the output of
- # the If op directly. This makes pruning work if the output of cond() is
- # fetched: the lowering pass converts the If outputs into IdentityN outputs,
- # which if fetched will cause all ops in the taken branch to be run (since
- # it takes all merge ops as input). After lowering, each output identity op
- # will end up with only the appropriate merge op as input.
- # TODO(b/79984175): this doesn't have to be a tuple once we covert to the
- # correct output structure
- tensors = tuple(array_ops.identity(t) for t in tensors)
+ outputs = _build_cond(pred, true_graph, false_graph,
+ true_graph.external_captures,
+ false_graph.external_captures,
+ name=scope)
return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
- tensors[:num_cond_outputs])
+ outputs)
@ops.RegisterGradient("If")
@@ -150,44 +111,83 @@
true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
- # Make the inputs to true_grad_graph and false_grad_graph match. Note that
- # this modifies true_grad_graph and false_grad_graph.
- grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph,
- true_grad_inputs, false_grad_inputs)
-
- # Add all intermediate tensors as function outputs so they're available for
- # higher-order gradient computations.
-
- true_grad_intermediates = _get_intermediates(true_grad_graph)
- false_grad_intermediates = _get_intermediates(false_grad_graph)
-
- # Save the original number of gradient outputs to return.
- num_grad_outputs = len(true_grad_graph.outputs)
-
- # Make the number/type of new intermediate outputs match.
- extra_true_grad_outputs, extra_false_grad_outputs = _pad_params(
- true_grad_graph, false_grad_graph,
- true_grad_intermediates, false_grad_intermediates)
-
- true_grad_graph.outputs.extend(extra_true_grad_outputs)
- false_grad_graph.outputs.extend(extra_false_grad_outputs)
-
- # Create the gradient If op.
- tensors = gen_functional_ops._if(
- op.inputs[0],
- grad_inputs, [t.dtype for t in true_grad_graph.outputs],
- util.create_new_tf_function(true_grad_graph),
- util.create_new_tf_function(false_grad_graph),
- output_shapes=_get_output_shapes(true_grad_graph.outputs,
- false_grad_graph.outputs))
-
- util.maybe_set_lowering_attr(tensors[0].op)
-
- # See comment in cond_v2.
- tensors = [array_ops.identity(t) for t in tensors]
+ outputs = _build_cond(op.inputs[0], true_grad_graph, false_grad_graph,
+ true_grad_inputs, false_grad_inputs)
# The predicate has no gradient.
- return [None] + tensors[:num_grad_outputs]
+ return [None] + outputs
+
+
+def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
+ name=None):
+ """Creates an If op from the specified predicate, branch functions and inputs.
+
+ Note that this modifies true_graph and false_graph to make the inputs match,
+ and to output all intermediates values so they're available for the gradient
+ computation.
+
+ true_graph and false_graph need not have the same input types, but they must
+ have the same outpute types.
+
+ Args:
+ pred: boolean Tensor
+ true_graph: FuncGraph
+ false_graph: FuncGraph
+ true_inputs: a list of Tensors to be passed to true_graph as input.
+ false_inputs: a list of Tensors to be passed to false_graph as input.
+ name: the name for the If op.
+
+ Returns:
+ A list of Tensors which are the outputs of the If op. Does not include added
+ intermediate outputs.
+ """
+ _check_same_outputs(true_graph, false_graph)
+
+ # Add inputs to true_graph and false_graph to make them match. Note that
+ # this modifies true_graph and false_graph.
+ cond_inputs = _make_inputs_match(true_graph, false_graph,
+ true_inputs, false_inputs)
+
+ # Add all intermediate tensors as function outputs so they're available for
+ # the gradient computation.
+
+ true_intermediates = _get_intermediates(true_graph)
+ false_intermediates = _get_intermediates(false_graph)
+
+ # Save the original number of outputs to return to the caller.
+ num_cond_outputs = len(true_graph.outputs)
+
+ # Make the number/type of new intermediate outputs match.
+ extra_true_outputs, extra_false_outputs = _pad_params(
+ true_graph, false_graph, true_intermediates, false_intermediates)
+
+ true_graph.outputs.extend(extra_true_outputs)
+ false_graph.outputs.extend(extra_false_outputs)
+
+ # Create the If op.
+ tensors = gen_functional_ops._if( # pylint: disable=protected-access
+ pred,
+ cond_inputs, [t.dtype for t in true_graph.outputs],
+ util.create_new_tf_function(true_graph),
+ util.create_new_tf_function(false_graph),
+ output_shapes=_get_output_shapes(true_graph.outputs,
+ false_graph.outputs),
+ name=name)
+
+ # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
+ util.maybe_set_lowering_attr(tensors[0].op)
+
+ # Return identities for each output of the If op, rather than the output of
+ # the If op directly. This makes pruning work if the output of cond() is
+ # fetched: the lowering pass converts the If outputs into IdentityN outputs,
+ # which if fetched will cause all ops in the taken branch to be run (since
+ # it takes all merge ops as input). After lowering, each output identity op
+ # will end up with only the appropriate merge op as input.
+ # TODO(b/79984175): this doesn't have to be a tuple once we covert to the
+ # correct output structure
+ tensors = [array_ops.identity(t) for t in tensors]
+
+ return tensors[:num_cond_outputs]
def _get_func_graphs(if_op):
@@ -264,7 +264,11 @@
# both branches have zero gradient.
for i in range(len(result)):
if result[i] is None:
- result[i] = array_ops.zeros_like(func_graph.inputs[i])
+ if func_graph.inputs[i].dtype == dtypes.resource:
+ result[i] = array_ops.zeros(
+ gen_resource_variable_ops.variable_shape(func_graph.inputs[i]))
+ else:
+ result[i] = array_ops.zeros_like(func_graph.inputs[i])
return result
diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py
index b86b174..ccfe3b6 100644
--- a/tensorflow/python/ops/confusion_matrix.py
+++ b/tensorflow/python/ops/confusion_matrix.py
@@ -90,12 +90,13 @@
return labels, predictions
-@tf_export(
- 'math.confusion_matrix',
- v1=['math.confusion_matrix', 'confusion_matrix'])
-@deprecation.deprecated_endpoints('confusion_matrix', 'train.confusion_matrix')
-def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32,
- name=None, weights=None):
+@tf_export('math.confusion_matrix', v1=[])
+def confusion_matrix(labels,
+ predictions,
+ num_classes=None,
+ weights=None,
+ dtype=dtypes.int32,
+ name=None):
"""Computes the confusion matrix from predictions and labels.
The matrix columns represent the prediction labels and the rows represent the
@@ -132,9 +133,9 @@
num_classes: The possible number of labels the classification task can
have. If this value is not provided, it will be calculated
using both predictions and labels array.
+ weights: An optional `Tensor` whose shape matches `predictions`.
dtype: Data type of the confusion matrix.
name: Scope name.
- weights: An optional `Tensor` whose shape matches `predictions`.
Returns:
A `Tensor` of type `dtype` with shape `[n, n]` representing the confusion
@@ -193,3 +194,65 @@
zero_matrix = array_ops.zeros(math_ops.to_int32(shape), dtype)
return sparse_ops.sparse_add(zero_matrix, cm_sparse)
+
+
+@tf_export(v1=['math.confusion_matrix', 'confusion_matrix'])
+@deprecation.deprecated_endpoints('confusion_matrix', 'train.confusion_matrix')
+def confusion_matrix_v1(labels,
+ predictions,
+ num_classes=None,
+ dtype=dtypes.int32,
+ name=None,
+ weights=None):
+ """Computes the confusion matrix from predictions and labels.
+
+ The matrix columns represent the prediction labels and the rows represent the
+ real labels. The confusion matrix is always a 2-D array of shape `[n, n]`,
+ where `n` is the number of valid labels for a given classification task. Both
+ prediction and labels must be 1-D arrays of the same shape in order for this
+ function to work.
+
+ If `num_classes` is `None`, then `num_classes` will be set to one plus the
+ maximum value in either predictions or labels. Class labels are expected to
+ start at 0. For example, if `num_classes` is 3, then the possible labels
+ would be `[0, 1, 2]`.
+
+ If `weights` is not `None`, then each prediction contributes its
+ corresponding weight to the total value of the confusion matrix cell.
+
+ For example:
+
+ ```python
+ tf.confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
+ [[0 0 0 0 0]
+ [0 0 1 0 0]
+ [0 0 1 0 0]
+ [0 0 0 0 0]
+ [0 0 0 0 1]]
+ ```
+
+ Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,
+ resulting in a 5x5 confusion matrix.
+
+ Args:
+ labels: 1-D `Tensor` of real labels for the classification task.
+ predictions: 1-D `Tensor` of predictions for a given classification.
+ num_classes: The possible number of labels the classification task can have.
+ If this value is not provided, it will be calculated using both
+ predictions and labels array.
+ dtype: Data type of the confusion matrix.
+ name: Scope name.
+ weights: An optional `Tensor` whose shape matches `predictions`.
+
+ Returns:
+ A `Tensor` of type `dtype` with shape `[n, n]` representing the confusion
+ matrix, where `n` is the number of possible labels in the classification
+ task.
+
+ Raises:
+ ValueError: If both predictions and labels are not 1-D vectors and have
+ mismatched shapes, or if `weights` is not `None` and its shape doesn't
+ match `predictions`.
+ """
+ return confusion_matrix(labels, predictions, num_classes, weights, dtype,
+ name)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index eab9b3f..a36a24e 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -158,7 +158,7 @@
with ops.name_scope(name, "Assert", [condition, data]) as name:
xs = ops.convert_n_to_tensor(data)
- if all([x.dtype in {dtypes.string, dtypes.int32} for x in xs]):
+ if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs):
# As a simple heuristic, we assume that string and int32 are
# on host to avoid the need to use cond. If it is not case,
# we will pay the price copying the tensor to host memory.
@@ -457,19 +457,19 @@
ValueError: If any of the inputs is None, or inputs are IndexedSlices and
some but not all have a dense_shape property.
"""
- if any([inp is None for inp in inputs]):
+ if any(inp is None for inp in inputs):
raise ValueError("At least one of the merge inputs is None: %s" % inputs)
with ops.name_scope(name, "Merge", inputs) as name:
inputs = [
ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref=True)
for inp in inputs
]
- if all([isinstance(v, ops.Tensor) for v in inputs]):
- if all([v.dtype._is_ref_dtype for v in inputs]): # pylint: disable=protected-access
+ if all(isinstance(v, ops.Tensor) for v in inputs):
+ if all(v.dtype._is_ref_dtype for v in inputs): # pylint: disable=protected-access
return gen_control_flow_ops.ref_merge(inputs, name)
else:
return gen_control_flow_ops.merge(inputs, name)
- elif all([isinstance(v, sparse_tensor.SparseTensor) for v in inputs]):
+ elif all(isinstance(v, sparse_tensor.SparseTensor) for v in inputs):
# Only handle the case when all inputs are SparseTensor.
values, _ = merge([inp.values for inp in inputs], name=name)
indices, chosen_index = gen_control_flow_ops.merge(
@@ -557,7 +557,7 @@
if shapes is None:
return
flat_shapes = nest.flatten(shapes)
- if not all([isinstance(s, tensor_shape.TensorShape) for s in flat_shapes]):
+ if not all(isinstance(s, tensor_shape.TensorShape) for s in flat_shapes):
raise ValueError("`shapes` must be a (possibly nested) list of shapes.")
# Check that the shapes of the inputs are less than the shape invariants,
# and set the shapes of `enter_vars` to the shape invariants.
@@ -3136,7 +3136,193 @@
# pylint: disable=redefined-outer-name
-@tf_export("while_loop")
+@tf_export("while_loop", v1=[])
+def while_loop_v2(cond,
+ body,
+ loop_vars,
+ shape_invariants=None,
+ parallel_iterations=10,
+ back_prop=True,
+ swap_memory=False,
+ maximum_iterations=None,
+ return_same_structure=False,
+ name=None):
+ """Repeat `body` while the condition `cond` is true.
+
+ `cond` is a callable returning a boolean scalar tensor. `body` is a callable
+ returning a (possibly nested) tuple, namedtuple or list of tensors of the same
+ arity (length and structure) and types as `loop_vars`. `loop_vars` is a
+ (possibly nested) tuple, namedtuple or list of tensors that is passed to both
+ `cond` and `body`. `cond` and `body` both take as many arguments as there are
+ `loop_vars`.
+
+ In addition to regular Tensors or IndexedSlices, the body may accept and
+ return TensorArray objects. The flows of the TensorArray objects will
+ be appropriately forwarded between loops and during gradient calculations.
+
+ Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
+ call to `while_loop`, and not at all during `Session.run()`). `while_loop`
+ stitches together the graph fragments created during the `cond` and `body`
+ calls with some additional graph nodes to create the graph flow that
+ repeats `body` until `cond` returns false.
+
+ For correctness, `tf.while_loop()` strictly enforces shape invariants for
+ the loop variables. A shape invariant is a (possibly partial) shape that
+ is unchanged across the iterations of the loop. An error will be raised
+ if the shape of a loop variable after an iteration is determined to be more
+ general than or incompatible with its shape invariant. For example, a shape
+ of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
+ compatible with [11, 17]. By default (if the argument `shape_invariants` is
+ not specified), it is assumed that the initial shape of each tensor in
+ `loop_vars` is the same in every iteration. The `shape_invariants` argument
+ allows the caller to specify a less specific shape invariant for each loop
+ variable, which is needed if the shape varies between iterations. The
+ `tf.Tensor.set_shape`
+ function may also be used in the `body` function to indicate that
+ the output loop variable has a particular shape. The shape invariant for
+ SparseTensor and IndexedSlices are treated specially as follows:
+
+ a) If a loop variable is a SparseTensor, the shape invariant must be
+ TensorShape([r]) where r is the rank of the dense tensor represented
+ by the sparse tensor. It means the shapes of the three tensors of the
+ SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
+ is the shape of the SparseTensor.dense_shape property. It must be the shape of
+ a vector.
+
+ b) If a loop variable is an IndexedSlices, the shape invariant must be
+ a shape invariant of the values tensor of the IndexedSlices. It means
+ the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
+ [shape.ndims]).
+
+ `while_loop` implements non-strict semantics, enabling multiple iterations
+ to run in parallel. The maximum number of parallel iterations can be
+ controlled by `parallel_iterations`, which gives users some control over
+ memory consumption and execution order. For correct programs, `while_loop`
+ should return the same result for any parallel_iterations > 0.
+
+ For training, TensorFlow stores the tensors that are produced in the
+ forward inference and are needed in back propagation. These tensors are a
+ main source of memory consumption and often cause OOM errors when training
+ on GPUs. When the flag swap_memory is true, we swap out these tensors from
+ GPU to CPU. This for example allows us to train RNN models with very long
+ sequences and large batches.
+
+ Args:
+ cond: A callable that represents the termination condition of the loop.
+ body: A callable that represents the loop body.
+ loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
+ `Tensor`, and `TensorArray` objects.
+ shape_invariants: The shape invariants for the loop variables.
+ parallel_iterations: The number of iterations allowed to run in parallel. It
+ must be a positive integer.
+ back_prop: Whether backprop is enabled for this while loop.
+ swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
+ maximum_iterations: Optional maximum number of iterations of the while loop
+ to run. If provided, the `cond` output is AND-ed with an additional
+ condition ensuring the number of iterations executed is no greater than
+ `maximum_iterations`.
+ return_same_structure: If True, output has same structure as `loop_vars`. If
+ eager execution is enabled, this is ignored (and always treated as True).
+ name: Optional name prefix for the returned tensors.
+
+ Returns:
+ The output tensors for the loop variables after the loop.
+ If `return_same_structure` is True, the return value has the same
+ structure as `loop_vars`.
+ If `return_same_structure` is False, the return value is a Tensor,
+ TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list
+ otherwise.
+
+ Raises:
+ TypeError: if `cond` or `body` is not callable.
+ ValueError: if `loop_vars` is empty.
+
+ Example:
+
+ ```python
+ i = tf.constant(0)
+ c = lambda i: tf.less(i, 10)
+ b = lambda i: tf.add(i, 1)
+ r = tf.while_loop(c, b, [i])
+ ```
+
+ Example with nesting and a namedtuple:
+
+ ```python
+ import collections
+ Pair = collections.namedtuple('Pair', 'j, k')
+ ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
+ c = lambda i, p: i < 10
+ b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
+ ijk_final = tf.while_loop(c, b, ijk_0)
+ ```
+
+ Example using shape_invariants:
+
+ ```python
+ i0 = tf.constant(0)
+ m0 = tf.ones([2, 2])
+ c = lambda i, m: i < 10
+ b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
+ tf.while_loop(
+ c, b, loop_vars=[i0, m0],
+ shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
+ ```
+
+ Example which demonstrates non-strict semantics: In the following
+ example, the final value of the counter `i` does not depend on `x`. So
+ the `while_loop` can increment the counter parallel to updates of `x`.
+ However, because the loop counter at one loop iteration depends
+ on the value at the previous iteration, the loop counter itself cannot
+ be incremented in parallel. Hence if we just want the final value of the
+ counter (which we print on the line `print(sess.run(i))`), then
+ `x` will never be incremented, but the counter will be updated on a
+ single thread. Conversely, if we want the value of the output (which we
+ print on the line `print(sess.run(out).shape)`), then the counter may be
+ incremented on its own thread, while `x` can be incremented in
+ parallel on a separate thread. In the extreme case, it is conceivable
+ that the thread incrementing the counter runs until completion before
+ `x` is incremented even a single time. The only thing that can never
+ happen is that the thread updating `x` can never get ahead of the
+ counter thread because the thread incrementing `x` depends on the value
+ of the counter.
+
+ ```python
+ import tensorflow as tf
+
+ n = 10000
+ x = tf.constant(list(range(n)))
+ c = lambda i, x: i < n
+ b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:"))
+ i, out = tf.while_loop(c, b, (0, x))
+ with tf.Session() as sess:
+ print(sess.run(i)) # prints [0] ... [9999]
+
+ # The following line may increment the counter and x in parallel.
+ # The counter thread may get ahead of the other thread, but not the
+ # other way around. So you may see things like
+ # [9996] x:[9987]
+ # meaning that the counter thread is on iteration 9996,
+ # while the other thread is on iteration 9987
+ print(sess.run(out).shape)
+ ```
+
+ """
+ return while_loop(
+ cond=cond,
+ body=body,
+ loop_vars=loop_vars,
+ shape_invariants=shape_invariants,
+ parallel_iterations=parallel_iterations,
+ back_prop=back_prop,
+ swap_memory=swap_memory,
+ name=name,
+ maximum_iterations=maximum_iterations,
+ return_same_structure=return_same_structure)
+
+
+# pylint: disable=redefined-outer-name
+@tf_export(v1=["while_loop"])
def while_loop(cond,
body,
loop_vars,
@@ -3536,7 +3722,43 @@
return no_op(name=name)
-@tf_export("tuple")
+@tf_export("tuple", v1=[])
+def tuple_v2(tensors, control_inputs=None, name=None):
+ """Group tensors together.
+
+ This creates a tuple of tensors with the same values as the `tensors`
+ argument, except that the value of each tensor is only returned after the
+ values of all tensors have been computed.
+
+ `control_inputs` contains additional ops that have to finish before this op
+ finishes, but whose outputs are not returned.
+
+ This can be used as a "join" mechanism for parallel computations: all the
+ argument tensors can be computed in parallel, but the values of any tensor
+ returned by `tuple` are only available after all the parallel computations
+ are done.
+
+ See also `tf.group` and
+ `tf.control_dependencies`.
+
+ Args:
+ tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
+ control_inputs: List of additional ops to finish before returning.
+ name: (optional) A name to use as a `name_scope` for the operation.
+
+ Returns:
+ Same as `tensors`.
+
+ Raises:
+ ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
+ TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
+ objects.
+
+ """
+ return tuple(tensors=tensors, name=name, control_inputs=control_inputs) # pylint: disable=redefined-builtin
+
+
+@tf_export(v1=["tuple"])
def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined-builtin
"""Group tensors together.
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 260af95..c020189 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -211,7 +211,7 @@
with self.cached_session() as sess:
self.evaluate(variables.global_variables_initializer())
for _ in range(10):
- sess.run([train_op])
+ self.evaluate([train_op])
def testResourceReadInLoop(self):
with ops.Graph().as_default():
@@ -270,7 +270,7 @@
with self.cached_session() as sess:
self.evaluate(variables.global_variables_initializer())
- self.assertAllEqual(*sess.run([static_grads, dynamic_grads]))
+ self.assertAllEqual(*self.evaluate([static_grads, dynamic_grads]))
def testIndexedSlicesGradientInCondInWhileLoop(self):
self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=False)
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index e1071af..3a7eb93 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -19,17 +19,27 @@
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_ctc_ops
+from tensorflow.python.ops import inplace_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.nn_grad import _BroadcastMul
+from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access, invalid-name
-@tf_export("nn.ctc_loss")
+@tf_export(v1=["nn.ctc_loss"])
def ctc_loss(labels, inputs, sequence_length,
preprocess_collapse_repeated=False,
ctc_merge_repeated=True,
@@ -336,6 +346,785 @@
ops.NotDifferentiable("CTCGreedyDecoder")
-
-
ops.NotDifferentiable("CTCBeamSearchDecoder")
+
+
+def _ctc_state_trans(label_seq):
+ """Compute CTC alignment model transition matrix.
+
+ Args:
+ label_seq: tensor of shape [batch_size, max_seq_length]
+
+ Returns:
+ tensor of shape [batch_size, states, states] with a state transition matrix
+ computed for each sequence of the batch.
+ """
+
+ with ops.name_scope("ctc_state_trans"):
+ label_seq = ops.convert_to_tensor(label_seq, name="label_seq")
+ batch_size = _get_dim(label_seq, 0)
+ num_labels = _get_dim(label_seq, 1)
+
+ num_label_states = num_labels + 1
+ num_states = 2 * num_label_states
+
+ label_states = math_ops.range(num_label_states)
+ blank_states = label_states + num_label_states
+
+ # Start state to first label.
+ start_to_label = [[1, 0]]
+
+ # Blank to label transitions.
+ blank_to_label = array_ops.stack([label_states[1:], blank_states[:-1]], 1)
+
+ # Label to blank transitions.
+ label_to_blank = array_ops.stack([blank_states, label_states], 1)
+
+ # Scatter transitions that don't depend on sequence.
+ indices = array_ops.concat(
+ [start_to_label, blank_to_label, label_to_blank], 0)
+ values = array_ops.ones([_get_dim(indices, 0)])
+ trans = array_ops.scatter_nd(
+ indices, values, shape=[num_states, num_states])
+ trans += linalg_ops.eye(num_states) # Self-loops.
+
+ # Label to label transitions. Disallow transitions between repeated labels
+ # with no blank state in between.
+ batch_idx = array_ops.zeros_like(label_states[2:])
+ indices = array_ops.stack(
+ [batch_idx, label_states[2:], label_states[1:-1]], 1)
+ indices = array_ops.tile(
+ array_ops.expand_dims(indices, 0), [batch_size, 1, 1])
+ batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0]
+ indices += array_ops.expand_dims(batch_idx, 1)
+ repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:])
+ values = 1.0 - math_ops.cast(repeats, dtypes.float32)
+ batched_shape = [batch_size, num_states, num_states]
+ label_to_label = array_ops.scatter_nd(indices, values, batched_shape)
+
+ return array_ops.expand_dims(trans, 0) + label_to_label
+
+
+def ctc_state_log_probs(seq_lengths, max_seq_length):
+ """Computes CTC alignment initial and final state log probabilities.
+
+ Create the initial/final state values directly as log values to avoid
+ having to take a float64 log on tpu (which does not exist).
+
+ Args:
+ seq_lengths: int tensor of shape [batch_size], seq lengths in the batch.
+ max_seq_length: int, max sequence length possible.
+
+ Returns:
+ initial_state_log_probs, final_state_log_probs
+ """
+
+ batch_size = _get_dim(seq_lengths, 0)
+ num_label_states = max_seq_length + 1
+ num_duration_states = 2
+ num_states = num_duration_states * num_label_states
+ log_0 = math_ops.cast(
+ math_ops.log(math_ops.cast(0, dtypes.float64) + 1e-307),
+ dtypes.float32)
+
+ initial_state_log_probs = array_ops.one_hot(
+ indices=array_ops.zeros([batch_size], dtype=dtypes.int32),
+ depth=num_states,
+ on_value=0.0,
+ off_value=log_0, axis=1)
+
+ label_final_state_mask = array_ops.one_hot(
+ seq_lengths, depth=num_label_states, axis=0)
+ duration_final_state_mask = array_ops.ones(
+ [num_duration_states, 1, batch_size])
+ final_state_mask = duration_final_state_mask * label_final_state_mask
+ final_state_log_probs = (1.0 - final_state_mask) * log_0
+ final_state_log_probs = array_ops.reshape(
+ final_state_log_probs, [num_states, batch_size])
+
+ return initial_state_log_probs, array_ops.transpose(final_state_log_probs)
+
+
+def _ilabel_to_state(labels, num_labels, ilabel_log_probs):
+ """Project ilabel log probs to state log probs."""
+
+ num_label_states = _get_dim(labels, 1)
+ blank = ilabel_log_probs[:, :, :1]
+ blank = array_ops.tile(blank, [1, 1, num_label_states + 1])
+ one_hot = array_ops.one_hot(labels, depth=num_labels)
+ one_hot = array_ops.expand_dims(one_hot, axis=0)
+ ilabel_log_probs = array_ops.expand_dims(ilabel_log_probs, axis=2)
+ state_log_probs = math_ops.reduce_sum(ilabel_log_probs * one_hot, axis=3)
+ state_log_probs = array_ops.concat([state_log_probs, blank], axis=2)
+ return array_ops.pad(
+ state_log_probs, [[0, 0], [0, 0], [1, 0]],
+ constant_values=math_ops.log(0.0))
+
+
+def _state_to_olabel(labels, num_labels, states):
+ """Sum state log probs to ilabel log probs."""
+
+ num_label_states = _get_dim(labels, 1) + 1
+ label_states = states[:, :, 1:num_label_states]
+ blank_states = states[:, :, num_label_states:]
+ one_hot = array_ops.one_hot(
+ labels - 1, depth=(num_labels - 1),
+ on_value=0.0, off_value=math_ops.log(0.0))
+ one_hot = array_ops.expand_dims(one_hot, axis=0)
+ label_states = array_ops.expand_dims(label_states, axis=3)
+ label_olabels = math_ops.reduce_logsumexp(label_states + one_hot, axis=2)
+ blank_olabels = math_ops.reduce_logsumexp(
+ blank_states, axis=2, keepdims=True)
+ return array_ops.concat([blank_olabels, label_olabels], axis=-1)
+
+
+# pylint: disable=redefined-outer-name
+def _state_to_olabel_unique(labels, num_labels, states, unique):
+ """Sum state log probs to ilabel log probs using unique label indices."""
+
+ num_label_states = _get_dim(labels, 1) + 1
+ label_states = states[:, :, 1:num_label_states]
+ blank_states = states[:, :, num_label_states:]
+
+ unique_y, unique_idx = unique
+ mul_reduce = _sum_states(unique_idx, label_states)
+
+ num_frames = states.shape[0]
+ batch_size = states.shape[1]
+ num_states = num_label_states - 1
+ batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0])
+ batch_state_major = array_ops.reshape(
+ batch_state_major, [batch_size * num_states, num_frames])
+ batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels
+ indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1)
+ indices = array_ops.reshape(indices, [-1, 1])
+ scatter = array_ops.scatter_nd(
+ indices=indices,
+ updates=batch_state_major,
+ shape=[batch_size * num_labels, num_frames])
+ scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])
+ scatter = array_ops.where(
+ math_ops.equal(scatter, 0.0),
+ array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)),
+ scatter)
+ label_olabels = array_ops.transpose(scatter, [2, 0, 1])
+ label_olabels = label_olabels[:, :, 1:]
+
+ blank_olabels = math_ops.reduce_logsumexp(
+ blank_states, axis=2, keepdims=True)
+
+ return array_ops.concat([blank_olabels, label_olabels], axis=-1)
+
+
+def ctc_loss_and_grad(logits, labels, label_length, logit_length, unique=None):
+ """Computes the CTC loss and gradients.
+
+ Most users will want fwd_bwd.ctc_loss
+
+ This function returns the computed gradient, it does not have a gradient
+ of its own defined.
+
+ Args:
+ logits: tensor of shape [frames, batch_size, num_labels]
+ labels: tensor of shape [batch_size, max_label_seq_length]
+ label_length: tensor of shape [batch_size]
+ Length of reference label sequence in labels.
+ logit_length: tensor of shape [batch_size]
+ Length of input sequence in logits.
+ unique: (optional) unique label indices as computed by unique(labels)
+ If supplied, enables an implementation that is faster and more memory
+ efficient on TPU.
+
+ Returns:
+ loss: tensor of shape [batch_size]
+ gradient: tensor of shape [frames, batch_size, num_labels]
+ """
+
+ num_labels = _get_dim(logits, 2)
+ max_label_seq_length = _get_dim(labels, 1)
+
+ ilabel_log_probs = nn_ops.log_softmax(logits)
+ state_log_probs = _ilabel_to_state(labels, num_labels, ilabel_log_probs)
+ state_trans_probs = _ctc_state_trans(labels)
+ initial_state_log_probs, final_state_log_probs = ctc_state_log_probs(
+ label_length, max_label_seq_length)
+ fwd_bwd_log_probs, log_likelihood = _forward_backward_log(
+ state_trans_log_probs=math_ops.log(state_trans_probs),
+ initial_state_log_probs=initial_state_log_probs,
+ final_state_log_probs=final_state_log_probs,
+ observed_log_probs=state_log_probs,
+ sequence_length=logit_length)
+
+ if unique:
+ olabel_log_probs = _state_to_olabel_unique(
+ labels, num_labels, fwd_bwd_log_probs, unique)
+ else:
+ olabel_log_probs = _state_to_olabel(labels, num_labels, fwd_bwd_log_probs)
+
+ grad = math_ops.exp(ilabel_log_probs) - math_ops.exp(olabel_log_probs)
+ loss = -log_likelihood
+ return loss, grad
+
+
+def _ctc_loss_grad(op, grad_loss, _):
+ grad = op.outputs[1]
+ grad = [array_ops.reshape(grad_loss, [1, -1, 1]) * grad]
+ grad += [None] * (len(op.inputs) - len(grad))
+ return grad
+
+
+def _ctc_loss_shape(op):
+ return [op.inputs[2].get_shape(), op.inputs[0].get_shape()]
+
+
+@tf_export("nn.ctc_loss", v1=["nn.ctc_loss_v2"])
+def ctc_loss_v2(labels, logits, label_length, logit_length,
+ logits_time_major=True, unique=None,
+ blank_index=None, name=None):
+ """Computes CTC (Connectionist Temporal Classification) loss.
+
+ This op implements the CTC loss as presented in the article:
+
+ [A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber.
+ Connectionist Temporal Classification: Labeling Unsegmented Sequence Data
+ with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA,
+ pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf)
+
+ Notes:
+ - Same as the "Classic CTC" in TensorFlow 1.x's tf.nn.ctc_loss setting of
+ preprocess_collapse_repeated=False, ctc_merge_repeated=True
+ - Labels may be supplied as either a dense, zero-padded tensor with a
+ vector of label sequence lengths OR as a SparseTensor.
+ - On TPU and GPU:
+ - Only dense padded labels are supported.
+ - On CPU:
+ - Caller may use SparseTensor or dense padded labels but calling with
+ a SparseTensor will be significantly faster.
+ - Default blank label is 0 rather num_classes - 1, unless overridden by
+ blank_index.
+
+ Args:
+ labels: tensor of shape [batch_size, max_label_seq_length] or SparseTensor
+ logits: tensor of shape [frames, batch_size, num_labels],
+ if logits_time_major == False, shape is [batch_size, frames, num_labels].
+ label_length: tensor of shape [batch_size], None if labels is SparseTensor
+ Length of reference label sequence in labels.
+ logit_length: tensor of shape [batch_size]
+ Length of input sequence in logits.
+ logits_time_major: (optional) If True (default), logits is shaped
+ [time, batch, logits]. If False, shape is [batch, time, logits]
+ unique: (optional) Unique label indices as computed by
+ ctc_unique_labels(labels). If supplied, enable a faster, memory
+ efficient implementation on TPU.
+ blank_index: (optional) Set the class index to use for the blank label.
+ Negative values will start from num_classes, ie, -1 will reproduce the
+ ctc_loss behavior of using num_classes - 1 for the blank symbol.
+ There is some memory/performance overhead to switching from the default
+ of 0 as an additional shifted copy of the logits may be created.
+ name: A name for this `Op`. Defaults to "ctc_loss_dense".
+
+ Returns:
+ loss: tensor of shape [batch_size], negative log probabilities.
+ """
+ if isinstance(labels, sparse_tensor.SparseTensor):
+ if blank_index is None:
+ raise ValueError(
+ "blank_index must be given when using SparseTensor labels.")
+
+ if blank_index < 0:
+ blank_index += _get_dim(logits, 2)
+
+ if blank_index != _get_dim(logits, 2) - 1:
+ logits = array_ops.concat([
+ logits[:, :, :blank_index],
+ logits[:, :, blank_index+1:],
+ logits[:, :, blank_index:blank_index+1],
+ ], axis=2)
+ labels = sparse_tensor.SparseTensor(
+ labels.indices,
+ array_ops.where(labels.values < blank_index,
+ labels.values,
+ labels.values - 1),
+ labels.dense_shape)
+
+ return ctc_loss(labels=labels,
+ inputs=logits,
+ sequence_length=logit_length,
+ time_major=logits_time_major)
+
+ if blank_index is None:
+ blank_index = 0
+
+ return ctc_loss_dense(labels=labels,
+ logits=logits,
+ label_length=label_length,
+ logit_length=logit_length,
+ logits_time_major=logits_time_major,
+ unique=unique,
+ blank_index=blank_index,
+ name=name)
+
+
+def ctc_loss_dense(labels, logits, label_length, logit_length,
+ logits_time_major=True, unique=None,
+ blank_index=0, name=None):
+ """Computes CTC (Connectionist Temporal Classification) loss.
+
+ This op implements the CTC loss as presented in the article:
+
+ [A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber.
+ Connectionist Temporal Classification: Labeling Unsegmented Sequence Data
+ with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA,
+ pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf)
+
+ Using the batched forward backward algorithm described in:
+
+ [Sim, K. C., Narayanan, A., Bagby, T., Sainath, T. N., & Bacchiani, M.
+ Improving the efficiency of forward-backward algorithm using batched
+ computation in TensorFlow.
+ Automatic Speech Recognition and Understanding Workshop (ASRU),
+ 2017 IEEE (pp. 258-264).
+ ](https://ieeexplore.ieee.org/iel7/8260578/8268903/08268944.pdf)
+
+ Notes:
+ Significant differences from tf.nn.ctc_loss:
+ Supports GPU and TPU (tf.nn.ctc_loss supports CPU only):
+ For batched operations, GPU and TPU are significantly faster than using
+ ctc_loss on CPU.
+ This implementation runs on CPU, but significantly slower than ctc_loss.
+ Blank label is 0 rather num_classes - 1, unless overridden by blank_index.
+ Logits and labels are dense arrays with padding rather than SparseTensor.
+ The only mode supported is the same as:
+ preprocess_collapse_repeated=False, ctc_merge_repeated=True
+ To collapse labels, the caller can preprocess label sequence first.
+
+ The dense implementation supports both CPU, GPU and TPU. A fast path is
+ provided that significantly improves memory use for large vocabulary if the
+ caller preprocesses label sequences to get unique label indices on the CPU
+ (eg. in the data input pipeline) using ctc_ops.unique and simplies this in
+ the optional "unique" kwarg. This is especially useful for TPU and GPU but
+ also works with if used on CPU.
+
+ Args:
+ labels: tensor of shape [batch_size, max_label_seq_length]
+ logits: tensor of shape [frames, batch_size, num_labels],
+ if logits_time_major == False, shape is [batch_size, frames, num_labels].
+ label_length: tensor of shape [batch_size]
+ Length of reference label sequence in labels.
+ logit_length: tensor of shape [batch_size]
+ Length of input sequence in logits.
+ logits_time_major: (optional) If True (default), logits is shaped
+ [time, batch, logits]. If False, shape is [batch, time, logits]
+ unique: (optional) Unique label indices as computed by unique(labels).
+ If supplied, enable a faster, memory efficient implementation on TPU.
+ blank_index: (optional) Set the class index to use for the blank label.
+ Negative values will start from num_classes, ie, -1 will reproduce the
+ ctc_loss behavior of using num_classes - 1 for the blank symbol.
+ There is some memory/performance overhead to switching from the default
+ of 0 as an additional shifted copy of the logits may be created.
+ name: A name for this `Op`. Defaults to "ctc_loss_dense".
+
+ Returns:
+ loss: tensor of shape [batch_size], negative log probabilities.
+ """
+
+ with ops.name_scope(name, "ctc_loss_dense",
+ [logits, labels, label_length, logit_length]):
+ logits = ops.convert_to_tensor(logits, name="logits")
+ labels = ops.convert_to_tensor(labels, name="labels")
+ label_length = ops.convert_to_tensor(label_length, name="label_length")
+ logit_length = ops.convert_to_tensor(logit_length, name="logit_length")
+
+ if not logits_time_major:
+ logits = array_ops.transpose(logits, perm=[1, 0, 2])
+
+ if blank_index != 0:
+ if blank_index < 0:
+ blank_index += _get_dim(logits, 2)
+ logits = array_ops.concat([
+ logits[:, :, blank_index:blank_index+1],
+ logits[:, :, :blank_index],
+ logits[:, :, blank_index+1:],
+ ], axis=2)
+ labels = array_ops.where(labels < blank_index, labels + 1, labels)
+
+ args = [logits, labels, label_length, logit_length]
+
+ if unique:
+ unique_y, unique_idx = unique
+ args.extend([unique_y, unique_idx])
+
+ # TODO(tombagby): Update to tfe.defun
+ @function.Defun(*[x.dtype for x in args],
+ python_grad_func=_ctc_loss_grad,
+ shape_func=_ctc_loss_shape)
+ def compute_ctc_loss(logits_t, labels_t, label_length_t, logit_length_t,
+ *unique_t):
+ """Compute CTC loss."""
+ logits_t.set_shape(logits.shape)
+ labels_t.set_shape(labels.shape)
+ label_length_t.set_shape(label_length.shape)
+ logit_length_t.set_shape(logit_length.shape)
+ kwargs = dict(
+ logits=logits_t,
+ labels=labels_t,
+ label_length=label_length_t,
+ logit_length=logit_length_t)
+ if unique_t:
+ kwargs["unique"] = unique_t
+ return ctc_loss_and_grad(**kwargs)
+
+ return compute_ctc_loss(*args)[0]
+
+
+@tf_export("nn.collapse_repeated")
+def collapse_repeated(labels, seq_length, name=None):
+ """Merge repeated labels into single labels.
+
+ Args:
+ labels: Tensor of shape (batch, max value in seq_length)
+ seq_length: Tensor of shape (batch), sequence length of each batch element.
+ name: A name for this `Op`. Defaults to "collapse_repeated_labels".
+
+ Returns:
+ tuple of Tensor of shape (batch, max_seq_length) with repeated labels
+ collapsed and padded to max_seq_length, eg:
+ [[A, A, B, B, A],
+ [A, B, C, D, E]] => [[A, B, A, 0, 0],
+ [A, B, C, D, E]]
+ and int tensor of shape [batch] with new sequence lengths.
+ """
+
+ with ops.name_scope(name, "collapse_repeated_labels",
+ [labels, seq_length]):
+ labels = ops.convert_to_tensor(labels, name="labels")
+ seq_length = ops.convert_to_tensor(seq_length, name="seq_length")
+
+ # Mask labels that don't equal previous label.
+ label_mask = array_ops.concat(
+ [array_ops.ones_like(labels[:, :1], dtypes.bool),
+ math_ops.not_equal(labels[:, 1:], labels[:, :-1])],
+ axis=1)
+
+ # Filter labels that aren't in the original sequence.
+ maxlen = _get_dim(labels, 1)
+ seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen)
+ label_mask = math_ops.logical_and(label_mask, seq_mask)
+
+ # Count masks for new sequence lengths.
+ new_seq_len = math_ops.reduce_sum(
+ math_ops.cast(label_mask, dtypes.int32), axis=1)
+
+ # Mask indexes based on sequence length mask.
+ new_maxlen = math_ops.reduce_max(new_seq_len)
+ idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen)
+
+ # Flatten everything and mask out labels to keep and sparse indices.
+ flat_labels = array_ops.reshape(labels, [-1])
+ flat_label_mask = array_ops.reshape(label_mask, [-1])
+ flat_idx_mask = array_ops.reshape(idx_mask, [-1])
+ idx = math_ops.range(_get_dim(flat_idx_mask, 0))
+
+ # Scatter to flat shape.
+ flat = array_ops.scatter_nd(
+ indices=array_ops.expand_dims(
+ array_ops.boolean_mask(idx, flat_idx_mask), axis=1),
+ updates=array_ops.boolean_mask(flat_labels, flat_label_mask),
+ shape=array_ops.shape(flat_idx_mask))
+
+ # Reshape back to square batch.
+ batch_size = _get_dim(labels, 0)
+ new_shape = [batch_size, new_maxlen]
+ return (array_ops.reshape(flat, new_shape),
+ math_ops.cast(new_seq_len, seq_length.dtype))
+
+
+def dense_labels_to_sparse(dense, length):
+ """Convert dense labels with sequence lengths to sparse tensor.
+
+ Args:
+ dense: tensor of shape [batch, max_length]
+ length: int tensor of shape [batch]
+ The length of each sequence in dense.
+
+ Returns:
+ tf.SparseTensor with values only for the valid elements of sequences.
+ """
+
+ flat_values = array_ops.reshape(dense, [-1])
+ flat_indices = math_ops.range(
+ array_ops.shape(flat_values, out_type=dtypes.int64)[0])
+ mask = array_ops.sequence_mask(length, maxlen=array_ops.shape(dense)[1])
+ flat_mask = array_ops.reshape(mask, [-1])
+ indices = array_ops.expand_dims(
+ array_ops.boolean_mask(flat_indices, flat_mask), 1)
+ values = array_ops.boolean_mask(flat_values, flat_mask)
+ sparse = sparse_tensor.SparseTensor(
+ indices=indices, values=math_ops.cast(values, dtypes.int32),
+ dense_shape=array_ops.shape(flat_values, out_type=dtypes.int64))
+ reshaped = sparse_ops.sparse_reshape(sparse, array_ops.shape(dense))
+ max_length = math_ops.reduce_max(length)
+ return sparse_tensor.SparseTensor(
+ indices=reshaped.indices,
+ values=reshaped.values,
+ dense_shape=[
+ math_ops.cast(reshaped.dense_shape[0], dtypes.int64),
+ math_ops.cast(max_length, dtypes.int64)])
+
+
+@tf_export("nn.ctc_unique_labels")
+def ctc_unique_labels(labels, name=None):
+ """Get unique labels and indices for batched labels for tf.nn.ctc_loss.
+
+ For use with tf.nn.ctc_loss_v2 optional argument `unique`: This op can be
+ used to preprocess labels in input pipeline to for better speed/memory use
+ computing the ctc loss on TPU.
+
+ Example:
+ ctc_unique_labels([[3, 4, 4, 3]]) ->
+ unique labels padded with 0: [[3, 4, 0, 0]]
+ indices of original labels in unique: [0, 1, 1, 0]
+
+ Args:
+ labels: tensor of shape [batch_size, max_label_length] padded with 0.
+ name: A name for this `Op`. Defaults to "ctc_unique_labels".
+
+ Returns:
+ tuple of
+ - unique labels, tensor of shape `[batch_size, max_label_length]`
+ - indices into unique labels, shape `[batch_size, max_label_length]`
+ """
+
+ with ops.name_scope(name, "ctc_unique_labels", [labels]):
+ labels = ops.convert_to_tensor(labels, name="labels")
+ def _unique(x):
+ u = array_ops.unique(x)
+ y = array_ops.pad(
+ u.y, [[0, _get_dim(u.idx, 0) - _get_dim(u.y, 0)]])
+ y = math_ops.cast(y, dtypes.int64)
+ return [y, u.idx]
+ return functional_ops.map_fn(
+ _unique, labels, dtype=[dtypes.int64, dtypes.int32])
+
+
+def _sum_states(idx, states):
+ """Take logsumexp for each unique state out of all label states.
+
+ Args:
+ idx: tensor of shape [batch, label_length]
+ For each sequence, indices into a set of unique labels as computed by
+ calling unique.
+ states: tensor of shape [frames, batch, label_length]
+ Log probabilities for each label state.
+
+ Returns:
+ tensor of shape [frames, batch_size, label_length], log probabilites summed
+ for each unique label of the sequence.
+ """
+
+ with ops.name_scope("sum_states"):
+ idx = ops.convert_to_tensor(idx, name="idx")
+ num_states = _get_dim(states, 2)
+ states = array_ops.expand_dims(states, axis=2)
+ one_hot = array_ops.one_hot(
+ idx, depth=num_states, on_value=0.0, off_value=math_ops.log(0.0),
+ axis=1)
+ return math_ops.reduce_logsumexp(states + one_hot, axis=-1)
+
+
+def _forward_backward_log(state_trans_log_probs, initial_state_log_probs,
+ final_state_log_probs, observed_log_probs,
+ sequence_length):
+ """Forward-backward algorithm computed in log domain.
+
+ Args:
+ state_trans_log_probs: tensor of shape [states, states] or
+ if different transition matrix per batch [batch_size, states, states]
+ initial_state_log_probs: tensor of shape [batch_size, states]
+ final_state_log_probs: tensor of shape [batch_size, states]
+ observed_log_probs: tensor of shape [frames, batch_size, states]
+ sequence_length: tensor of shape [batch_size]
+
+ Returns:
+ forward backward log probabilites: tensor of shape [frames, batch, states]
+ log_likelihood: tensor of shape [batch_size]
+
+ Raises:
+ ValueError: If state_trans_log_probs has unknown or incorrect rank.
+ """
+
+ if state_trans_log_probs.shape.ndims == 2:
+ perm = [1, 0]
+ elif state_trans_log_probs.shape.ndims == 3:
+ perm = [0, 2, 1]
+ else:
+ raise ValueError(
+ "state_trans_log_probs rank must be known and == 2 or 3, is: %s" %
+ state_trans_log_probs.shape.ndims)
+
+ bwd_state_trans_log_probs = array_ops.transpose(state_trans_log_probs, perm)
+ batch_size = _get_dim(observed_log_probs, 1)
+
+ def _forward(state_log_prob, obs_log_prob):
+ state_log_prob = array_ops.expand_dims(state_log_prob, axis=1) # Broadcast.
+ state_log_prob += state_trans_log_probs
+ state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)
+ state_log_prob += obs_log_prob
+ log_prob_sum = math_ops.reduce_logsumexp(
+ state_log_prob, axis=-1, keepdims=True)
+ state_log_prob -= log_prob_sum
+ return state_log_prob
+
+ fwd = _scan(_forward, observed_log_probs, initial_state_log_probs,
+ inclusive=True)
+
+ def _backward(accs, elems):
+ """Calculate log probs and cumulative sum masked for sequence length."""
+ state_log_prob, cum_log_sum = accs
+ obs_log_prob, mask = elems
+ state_log_prob += obs_log_prob
+ state_log_prob = array_ops.expand_dims(state_log_prob, axis=1) # Broadcast.
+ state_log_prob += bwd_state_trans_log_probs
+ state_log_prob = math_ops.reduce_logsumexp(state_log_prob, axis=-1)
+
+ log_prob_sum = math_ops.reduce_logsumexp(
+ state_log_prob, axis=-1, keepdims=True)
+ state_log_prob -= log_prob_sum
+
+ cum_log_sum += array_ops.squeeze(log_prob_sum) * mask
+ batched_mask = array_ops.expand_dims(mask, axis=1)
+ out = state_log_prob * batched_mask
+ out += final_state_log_probs * (1.0 - batched_mask)
+ return out, cum_log_sum
+
+ zero_log_sum = array_ops.zeros([batch_size])
+ maxlen = _get_dim(observed_log_probs, 0)
+ mask = array_ops.sequence_mask(sequence_length, maxlen, dtypes.float32)
+ mask = array_ops.transpose(mask, perm=[1, 0])
+
+ bwd, cum_log_sum = _scan(_backward, (observed_log_probs, mask),
+ (final_state_log_probs, zero_log_sum),
+ reverse=True, inclusive=True)
+
+ fwd_bwd_log_probs = fwd[1:] + bwd[1:]
+ fwd_bwd_log_probs_sum = math_ops.reduce_logsumexp(
+ fwd_bwd_log_probs, axis=2, keepdims=True)
+ fwd_bwd_log_probs -= fwd_bwd_log_probs_sum
+ fwd_bwd_log_probs += math_ops.log(array_ops.expand_dims(mask, axis=2))
+
+ log_likelihood = bwd[0, :, 0] + cum_log_sum[0]
+
+ return fwd_bwd_log_probs, log_likelihood
+
+
+# TODO(tombagby): This is currently faster for the ctc implementation than using
+# functional_ops.scan, but could be replaced by that or something similar if
+# things change.
+def _scan(fn, elems, initial, reverse=False, inclusive=False, final_only=False):
+ """Repeatedly applies callable `fn` to a sequence of elements.
+
+ Implemented by functional_ops.While, tpu friendly, no gradient.
+
+ This is similar to functional_ops.scan but significantly faster on tpu/gpu
+ for the forward backward use case.
+
+ Examples:
+ scan(lambda a, e: a + e, [1.0, 2.0, 3.0], 1.0) => [2.0, 3.0, 4.0]
+
+ Multiple accumulators:
+ scan(lambda a, e: (a[0] + e, a[1] * e), [1.0, 2.0, 3.0], (0.0, 1.0))
+
+ Multiple inputs:
+ scan(lambda a, e: a + (e[0] * e[1]), (elems1, elems2), 0.0)
+
+ Args:
+ fn: callable, fn(accumulators, element) return new accumulator values.
+ The (possibly nested) sequence of accumulators is the same as `initial`
+ and the return value must have the same structure.
+ elems: A (possibly nested) tensor which will be unpacked along the first
+ dimension. The resulting slices will be the second argument to fn. The
+ first dimension of all nested input tensors must be the same.
+ initial: A tensor or (possibly nested) sequence of tensors with initial
+ values for the accumulators.
+ reverse: (optional) True enables scan and output elems in reverse order.
+ inclusive: (optional) True includes the initial accumulator values in the
+ output. Length of output will be len(elem sequence) + 1. Not meaningful
+ if final_only is True.
+ final_only: (optional) When True, return only the final accumulated values,
+ not the concatenation of accumulated values for each input.
+
+ Returns:
+ A (possibly nested) sequence of tensors with the results of applying fn
+ to tensors unpacked from elems and previous accumulator values.
+ """
+
+ flat_elems = [ops.convert_to_tensor(x) for x in nest.flatten(elems)]
+ num_elems = array_ops.shape(flat_elems[0])[0]
+ pack_elems = lambda x: nest.pack_sequence_as(structure=elems, flat_sequence=x)
+ flat_initial = [ops.convert_to_tensor(x) for x in nest.flatten(initial)]
+ pack = lambda x: nest.pack_sequence_as(structure=initial, flat_sequence=x)
+ accum_dtypes = [x.dtype for x in flat_initial]
+ num_accums = len(flat_initial)
+
+ # Types for counter, [outputs], [accumulators] loop arguments.
+ if final_only:
+ loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes
+ else:
+ loop_dtypes = [dtypes.int32, dtypes.int32] + accum_dtypes + accum_dtypes
+
+ # TODO(tombagby): Update to tfe.defun
+ @function.Defun(*loop_dtypes)
+ def cond(i, num_elems, *args):
+ del args
+ return i >= 0 if reverse else i < num_elems
+
+ # The loop *args are [output tensors] + [accumulator tensors] which must
+ # be paired. Each output corresponds to one accumulator.
+ @function.Defun(*loop_dtypes)
+ def body(i, num_elems, *args):
+ """Loop body."""
+ i.set_shape([])
+ if final_only:
+ accum = args
+ else:
+ out, accum = args[:num_accums], args[num_accums:]
+ slices = [array_ops.gather(e, i) for e in flat_elems]
+ accum = fn(pack(accum), pack_elems(slices))
+ flat_accum = nest.flatten(accum)
+ if final_only:
+ new_out = []
+ else:
+ update_i = i + 1 if inclusive and not reverse else i
+ new_out = [inplace_ops.alias_inplace_update(x, update_i, y)
+ for x, y in zip(out, flat_accum)]
+ i = i - 1 if reverse else i + 1
+ return [i, num_elems] + new_out + flat_accum
+
+ init_i = (array_ops.shape(flat_elems[0])[0] - 1 if reverse
+ else constant_op.constant(0, dtype=dtypes.int32))
+ outputs = []
+ if not final_only:
+ num_outputs = array_ops.shape(flat_elems[0])[0] + (1 if inclusive else 0)
+ for initial_accum in flat_initial:
+ out_shape = array_ops.concat(
+ [[num_outputs], array_ops.shape(initial_accum)], 0)
+ out = inplace_ops.empty(out_shape, dtype=initial_accum.dtype, init=True)
+ if inclusive:
+ out = inplace_ops.alias_inplace_add(
+ out, init_i + (1 if reverse else 0), initial_accum)
+ outputs.append(out)
+ loop_in = [init_i, num_elems] + outputs + flat_initial
+ hostmem = [
+ i for i, x in enumerate(loop_in)
+ if x.dtype.base_dtype in (dtypes.int32, dtypes.int64)
+ ]
+
+ # TODO(tombagby): Update to while_v2.
+ loop_results = functional_ops.While(loop_in, cond, body, hostmem=hostmem)
+ out = loop_results[2:num_accums + 2]
+ return pack(out)
+
+
+def _get_dim(tensor, i):
+ """Get value of tensor shape[i] preferring static value if available."""
+ return tensor.shape[i].value or array_ops.shape(tensor)[i]
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 0fac799..bb08dba 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -79,7 +79,7 @@
shapes = [shapes]
shapes = [tensor_shape.as_shape(shape) for shape in shapes]
if not unknown_dim_allowed:
- if any([not shape.is_fully_defined() for shape in shapes]):
+ if any(not shape.is_fully_defined() for shape in shapes):
raise ValueError("All shapes must be fully defined: %s" % shapes)
if not unknown_rank_allowed:
if any([shape.dims is None for shape in shapes]):
@@ -198,11 +198,11 @@
raise TypeError("A list of queues expected")
dtypes = queues[0].dtypes
- if not all([dtypes == q.dtypes for q in queues[1:]]):
+ if not all(dtypes == q.dtypes for q in queues[1:]):
raise TypeError("Queues do not have matching component dtypes.")
names = queues[0].names
- if not all([names == q.names for q in queues[1:]]):
+ if not all(names == q.names for q in queues[1:]):
raise TypeError("Queues do not have matching component names.")
queue_shapes = [q.shapes for q in queues]
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 53c0709..c8f5cb8 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -539,7 +539,7 @@
return consumers
-@tf_export("gradients")
+@tf_export(v1=["gradients"])
def gradients(ys,
xs,
grad_ys=None,
@@ -655,6 +655,119 @@
unconnected_gradients)
+@tf_export("gradients", v1=[])
+def gradients_v2(ys, # pylint: disable=invalid-name
+ xs,
+ grad_ys=None,
+ name="gradients",
+ gate_gradients=False,
+ aggregation_method=None,
+ stop_gradients=None,
+ unconnected_gradients=UnconnectedGradients.NONE):
+ """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.
+
+ `ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
+ is a list of `Tensor`, holding the gradients received by the
+ `ys`. The list must be the same length as `ys`.
+
+ `gradients()` adds ops to the graph to output the derivatives of `ys` with
+ respect to `xs`. It returns a list of `Tensor` of length `len(xs)` where
+ each tensor is the `sum(dy/dx)` for y in `ys`.
+
+ `grad_ys` is a list of tensors of the same length as `ys` that holds
+ the initial gradients for each y in `ys`. When `grad_ys` is None,
+ we fill in a tensor of '1's of the shape of y for each y in `ys`. A
+ user can provide their own initial `grad_ys` to compute the
+ derivatives using a different initial gradient for each y (e.g., if
+ one wanted to weight the gradient differently for each value in
+ each y).
+
+ `stop_gradients` is a `Tensor` or a list of tensors to be considered constant
+ with respect to all `xs`. These tensors will not be backpropagated through,
+ as though they had been explicitly disconnected using `stop_gradient`. Among
+ other things, this allows computation of partial derivatives as opposed to
+ total derivatives. For example:
+
+ ```python
+ a = tf.constant(0.)
+ b = 2 * a
+ g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])
+ ```
+
+ Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
+ total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
+ influence of `a` on `b` and evaluate to `[3.0, 1.0]`. Note that the above is
+ equivalent to:
+
+ ```python
+ a = tf.stop_gradient(tf.constant(0.))
+ b = tf.stop_gradient(2 * a)
+ g = tf.gradients(a + b, [a, b])
+ ```
+
+ `stop_gradients` provides a way of stopping gradient after the graph has
+ already been constructed, as compared to `tf.stop_gradient` which is used
+ during graph construction. When the two approaches are combined,
+ backpropagation stops at both `tf.stop_gradient` nodes and nodes in
+ `stop_gradients`, whichever is encountered first.
+
+ All integer tensors are considered constant with respect to all `xs`, as if
+ they were included in `stop_gradients`.
+
+ `unconnected_gradients` determines the value returned for each x in xs if it
+ is unconnected in the graph to ys. By default this is None to safeguard
+ against errors. MAthematically these gradients are zero which can be requested
+ using the `'zero'` option. `tf.UnconnectedGradients` provides the
+ following options and behaviors:
+
+ ```python
+ a = tf.ones([1, 2])
+ b = tf.ones([3, 1])
+ g1 = tf.gradients([b], [a], unnconnected_gradients='none')
+ sess.run(g1) # [None]
+
+ g2 = tf.gradients([b], [a], unconnected_gradients='zero')
+ sess.run(g2) # [array([[0., 0.]], dtype=float32)]
+ ```
+
+
+ Args:
+ ys: A `Tensor` or list of tensors to be differentiated.
+ xs: A `Tensor` or list of tensors to be used for differentiation.
+ grad_ys: Optional. A `Tensor` or list of tensors the same size as
+ `ys` and holding the gradients computed for each y in `ys`.
+ name: Optional name to use for grouping all the gradient ops together.
+ defaults to 'gradients'.
+ gate_gradients: If True, add a tuple around the gradients returned
+ for an operations. This avoids some race conditions.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Accepted values are constants defined in the class `AggregationMethod`.
+ stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
+ through.
+ unconnected_gradients: Optional. Specifies the gradient value returned when
+ the given input tensors are unconnected. Accepted values are constants
+ defined in the class `tf.UnconnectedGradients` and the default value is
+ `none`.
+
+ Returns:
+ A list of `sum(dy/dx)` for each x in `xs`.
+
+ Raises:
+ LookupError: if one of the operations between `x` and `y` does not
+ have a registered gradient function.
+ ValueError: if the arguments are invalid.
+ RuntimeError: if called in Eager mode.
+
+ """
+ # Creating the gradient graph for control flow mutates Operations.
+ # _mutation_lock ensures a Session.run call cannot occur between creating and
+ # mutating new ops.
+ with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
+ return _GradientsHelper(ys, xs, grad_ys, name, True, gate_gradients,
+ aggregation_method, stop_gradients,
+ unconnected_gradients)
+
+
def _GradientsHelper(ys,
xs,
grad_ys=None,
@@ -895,7 +1008,7 @@
if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
return True
if out_grad and isinstance(out_grad, collections.Sequence):
- if any([g is not None for g in out_grad]):
+ if any(g is not None for g in out_grad):
return True
return False
@@ -1110,11 +1223,11 @@
assert control_flow_util.IsLoopSwitch(op)
continue
# Grads have to be Tensors or IndexedSlices
- if (isinstance(out_grad, collections.Sequence) and not all([
+ if (isinstance(out_grad, collections.Sequence) and not all(
isinstance(g, (ops.Tensor, ops.IndexedSlices))
for g in out_grad
if g is not None
- ])):
+ )):
raise TypeError("gradients have to be either all Tensors "
"or all IndexedSlices")
# Aggregate multiple gradients, and convert [] to None.
@@ -1122,7 +1235,7 @@
if len(out_grad) < 2:
used = "nop"
out_grads[i] = out_grad[0]
- elif all([isinstance(g, ops.Tensor) for g in out_grad if g is not None]):
+ elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None):
tensor_shape = _AccumulatorShape(out_grad)
if (aggregation_method == AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
and len(out_grad) > 2 and tensor_shape.is_fully_defined()):
@@ -1239,7 +1352,7 @@
return gradients(elemwise_products, xs)
-@tf_export("hessians")
+@tf_export(v1=["hessians"])
def hessians(ys,
xs,
name="hessians",
@@ -1304,3 +1417,16 @@
array_ops.concat((_shape, _shape), 0))
hessians.append(_reshaped_hessian)
return hessians
+
+
+@tf_export("hessians", v1=[])
+def HessiansV2(ys,
+ xs,
+ gate_gradients=False,
+ aggregation_method=None,
+ name="hessians"):
+ return hessians(ys, xs, name=name, gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method)
+
+
+HessiansV2.__doc__ = hessians.__doc__
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 4d1357a..229393c 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -24,6 +24,7 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
@@ -1759,7 +1760,7 @@
orig_dtype)
-@tf_export('image.is_jpeg')
+@tf_export('io.is_jpeg', 'image.is_jpeg', v1=['io.is_jpeg', 'image.is_jpeg'])
def is_jpeg(contents, name=None):
r"""Convenience function to check if the 'contents' encodes a JPEG image.
@@ -1794,8 +1795,28 @@
substr = string_ops.substr(contents, 0, 3)
return math_ops.equal(substr, b'\211PN', name=name)
+tf_export('io.decode_and_crop_jpeg', 'image.decode_and_crop_jpeg',
+ v1=['io.decode_and_crop_jpeg', 'image.decode_and_crop_jpeg'])(
+ gen_image_ops.decode_and_crop_jpeg)
-@tf_export('image.decode_image')
+tf_export('io.decode_bmp', 'image.decode_bmp',
+ v1=['io.decode_bmp', 'image.decode_bmp'])(gen_image_ops.decode_bmp)
+tf_export('io.decode_gif', 'image.decode_gif',
+ v1=['io.decode_gif', 'image.decode_gif'])(gen_image_ops.decode_gif)
+tf_export('io.decode_jpeg', 'image.decode_jpeg',
+ v1=['io.decode_jpeg', 'image.decode_jpeg'])(gen_image_ops.decode_jpeg)
+tf_export('io.decode_png', 'image.decode_png',
+ v1=['io.decode_png', 'image.decode_png'])(gen_image_ops.decode_png)
+
+tf_export('io.encode_jpeg', 'image.encode_jpeg',
+ v1=['io.encode_jpeg', 'image.encode_jpeg'])(gen_image_ops.encode_jpeg)
+tf_export('io.extract_jpeg_shape', 'image.extract_jpeg_shape',
+ v1=['io.extract_jpeg_shape', 'image.extract_jpeg_shape'])(
+ gen_image_ops.extract_jpeg_shape)
+
+
+@tf_export('io.decode_image', 'image.decode_image',
+ v1=['io.decode_image', 'image.decode_image'])
def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None):
"""Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`,
and `decode_png`.
@@ -1965,7 +1986,114 @@
return tot_var
-@tf_export('image.sample_distorted_bounding_box')
+@tf_export('image.sample_distorted_bounding_box', v1=[])
+def sample_distorted_bounding_box_v2(image_size,
+ bounding_boxes,
+ seed=0,
+ min_object_covered=0.1,
+ aspect_ratio_range=None,
+ area_range=None,
+ max_attempts=None,
+ use_image_if_no_bounding_boxes=None,
+ name=None):
+ """Generate a single randomly distorted bounding box for an image.
+
+ Bounding box annotations are often supplied in addition to ground-truth labels
+ in image recognition or object localization tasks. A common technique for
+ training such a system is to randomly distort an image while preserving
+ its content, i.e. *data augmentation*. This Op outputs a randomly distorted
+ localization of an object, i.e. bounding box, given an `image_size`,
+ `bounding_boxes` and a series of constraints.
+
+ The output of this Op is a single bounding box that may be used to crop the
+ original image. The output is returned as 3 tensors: `begin`, `size` and
+ `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the
+ image. The latter may be supplied to `tf.image.draw_bounding_boxes` to
+ visualize what the bounding box looks like.
+
+ Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`.
+ The bounding box coordinates are floats in `[0.0, 1.0]` relative to the width
+ and height of the underlying image.
+
+ For example,
+
+ ```python
+ # Generate a single distorted bounding box.
+ begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
+ tf.shape(image),
+ bounding_boxes=bounding_boxes,
+ min_object_covered=0.1)
+
+ # Draw the bounding box in an image summary.
+ image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
+ bbox_for_draw)
+ tf.summary.image('images_with_box', image_with_box)
+
+ # Employ the bounding box to distort the image.
+ distorted_image = tf.slice(image, begin, size)
+ ```
+
+ Note that if no bounding box information is available, setting
+ `use_image_if_no_bounding_boxes = true` will assume there is a single implicit
+ bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is
+ false and no bounding boxes are supplied, an error is raised.
+
+ Args:
+ image_size: A `Tensor`. Must be one of the following types: `uint8`, `int8`,
+ `int16`, `int32`, `int64`.
+ 1-D, containing `[height, width, channels]`.
+ bounding_boxes: A `Tensor` of type `float32`.
+ 3-D with shape `[batch, N, 4]` describing the N bounding boxes
+ associated with the image.
+ seed: An optional `int`. Defaults to `0`.
+ If either `seed` or `seed2` are set to non-zero, the random number
+ generator is seeded by the given `seed`. Otherwise, it is seeded by a
+ random seed.
+ min_object_covered: A Tensor of type `float32`. Defaults to `0.1`.
+ The cropped area of the image must contain at least this
+ fraction of any bounding box supplied. The value of this parameter should
+ be non-negative. In the case of 0, the cropped area does not need to
+ overlap any of the bounding boxes supplied.
+ aspect_ratio_range: An optional list of `floats`. Defaults to `[0.75,
+ 1.33]`.
+ The cropped area of the image must have an aspect `ratio =
+ width / height` within this range.
+ area_range: An optional list of `floats`. Defaults to `[0.05, 1]`.
+ The cropped area of the image must contain a fraction of the
+ supplied image within this range.
+ max_attempts: An optional `int`. Defaults to `100`.
+ Number of attempts at generating a cropped region of the image
+ of the specified constraints. After `max_attempts` failures, return the
+ entire image.
+ use_image_if_no_bounding_boxes: An optional `bool`. Defaults to `False`.
+ Controls behavior if no bounding boxes supplied.
+ If true, assume an implicit bounding box covering the whole input. If
+ false, raise an error.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tuple of `Tensor` objects (begin, size, bboxes).
+
+ begin: A `Tensor`. Has the same type as `image_size`. 1-D, containing
+ `[offset_height, offset_width, 0]`. Provide as input to
+ `tf.slice`.
+ size: A `Tensor`. Has the same type as `image_size`. 1-D, containing
+ `[target_height, target_width, -1]`. Provide as input to
+ `tf.slice`.
+ bboxes: A `Tensor` of type `float32`. 3-D with shape `[1, 1, 4]` containing
+ the distorted bounding box.
+ Provide as input to `tf.image.draw_bounding_boxes`.
+ """
+ seed1, seed2 = random_seed.get_seed(seed) if seed else (0, 0)
+ return sample_distorted_bounding_box(
+ image_size, bounding_boxes, seed1, seed2, min_object_covered,
+ aspect_ratio_range, area_range, max_attempts,
+ use_image_if_no_bounding_boxes, name)
+
+
+@tf_export(v1=['image.sample_distorted_bounding_box'])
+@deprecation.deprecated(date=None, instructions='`seed2` arg is deprecated.'
+ 'Use sample_distorted_bounding_box_v2 instead.')
def sample_distorted_bounding_box(image_size,
bounding_boxes,
seed=None,
@@ -2861,3 +2989,72 @@
'instead.'))
tf_export(v1=['image.resize_nearest_neighbor'])(
resize_nearest_neighbor_deprecation(gen_image_ops.resize_nearest_neighbor))
+
+
+@tf_export('image.crop_and_resize', v1=[])
+def crop_and_resize_v2(
+ image,
+ boxes,
+ box_indices,
+ crop_size,
+ method='bilinear',
+ extrapolation_value=0,
+ name=None):
+ """Extracts crops from the input image tensor and resizes them.
+
+ Extracts crops from the input image tensor and resizes them using bilinear
+ sampling or nearest neighbor sampling (possibly with aspect ratio change) to a
+ common output size specified by `crop_size`. This is more general than the
+ `crop_to_bounding_box` op which extracts a fixed size slice from the input
+ image and does not allow resizing or aspect ratio change.
+
+ Returns a tensor with `crops` from the input `image` at positions defined at
+ the bounding box locations in `boxes`. The cropped boxes are all resized (with
+ bilinear or nearest neighbor interpolation) to a fixed
+ `size = [crop_height, crop_width]`. The result is a 4-D tensor
+ `[num_boxes, crop_height, crop_width, depth]`. The resizing is corner aligned.
+ In particular, if `boxes = [[0, 0, 1, 1]]`, the method will give identical
+ results to using `tf.image.resize_bilinear()` or
+ `tf.image.resize_nearest_neighbor()`(depends on the `method` argument) with
+ `align_corners=True`.
+
+ Args:
+ image: A 4-D tensor of shape `[batch, image_height, image_width, depth]`.
+ Both `image_height` and `image_width` need to be positive.
+ boxes: A 2-D tensor of shape `[num_boxes, 4]`. The `i`-th row of the tensor
+ specifies the coordinates of a box in the `box_ind[i]` image and is
+ specified in normalized coordinates `[y1, x1, y2, x2]`. A normalized
+ coordinate value of `y` is mapped to the image coordinate at `y *
+ (image_height - 1)`, so as the `[0, 1]` interval of normalized image
+ height is mapped to `[0, image_height - 1]` in image height coordinates.
+ We do allow `y1` > `y2`, in which case the sampled crop is an up-down
+ flipped version of the original image. The width dimension is treated
+ similarly. Normalized coordinates outside the `[0, 1]` range are allowed,
+ in which case we use `extrapolation_value` to extrapolate the input image
+ values.
+ box_indices: A 1-D tensor of shape `[num_boxes]` with int32 values in `[0,
+ batch)`. The value of `box_ind[i]` specifies the image that the `i`-th box
+ refers to.
+ crop_size: A 1-D tensor of 2 elements, `size = [crop_height, crop_width]`.
+ All cropped image patches are resized to this size. The aspect ratio of
+ the image content is not preserved. Both `crop_height` and `crop_width`
+ need to be positive.
+ method: An optional string specifying the sampling method for resizing. It
+ can be either `"bilinear"` or `"nearest"` and default to `"bilinear"`.
+ Currently two sampling methods are supported: Bilinear and Nearest
+ Neighbor.
+ extrapolation_value: An optional `float`. Defaults to `0`. Value used for
+ extrapolation, when applicable.
+ name: A name for the operation (optional).
+
+ Returns:
+ A 4-D tensor of shape `[num_boxes, crop_height, crop_width, depth]`.
+ """
+ return gen_image_ops.crop_and_resize(
+ image, boxes, box_indices, crop_size, method, extrapolation_value, name)
+
+
+crop_and_resize_deprecation = deprecation.deprecated_args(
+ None, 'box_ind is deprecated, use box_indices instead', 'box_ind')
+tf_export(v1=['image.crop_and_resize'])(
+ crop_and_resize_deprecation(gen_image_ops.crop_and_resize))
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index de82f4f..71a574e 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -70,7 +70,8 @@
split2 = list(map(image_ops.hsv_to_rgb, split1))
join1 = array_ops.stack(split1)
join2 = array_ops.stack(split2)
- batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2])
+ batch1, batch2, join1, join2 = self.evaluate(
+ [batch1, batch2, join1, join2])
# Verify that processing batch elements together is the same as separate
self.assertAllClose(batch1, join1)
@@ -109,7 +110,8 @@
split2 = list(map(image_ops.yiq_to_rgb, split1))
join1 = array_ops.stack(split1)
join2 = array_ops.stack(split2)
- batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2])
+ batch1, batch2, join1, join2 = self.evaluate(
+ [batch1, batch2, join1, join2])
# Verify that processing batch elements together is the same as separate
self.assertAllClose(batch1, join1, rtol=1e-4, atol=1e-4)
@@ -138,7 +140,8 @@
split2 = list(map(image_ops.yuv_to_rgb, split1))
join1 = array_ops.stack(split1)
join2 = array_ops.stack(split2)
- batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2])
+ batch1, batch2, join1, join2 = self.evaluate(
+ [batch1, batch2, join1, join2])
# Verify that processing batch elements together is the same as separate
self.assertAllClose(batch1, join1, rtol=1e-4, atol=1e-4)
@@ -2265,7 +2268,7 @@
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(image, [target_height, target_width], opt)
yshape = array_ops.shape(y)
- resized, newshape = sess.run([y, yshape])
+ resized, newshape = self.evaluate([y, yshape])
self.assertAllEqual(img_shape, newshape)
self.assertAllClose(resized, img_np, atol=1e-5)
@@ -2379,7 +2382,7 @@
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(image, [height, width], opt)
yshape = array_ops.shape(y)
- resized, newshape = sess.run([y, yshape])
+ resized, newshape = self.evaluate([y, yshape])
self.assertAllEqual(img_shape, newshape)
self.assertAllClose(resized, img_np, atol=1e-5)
@@ -3066,7 +3069,7 @@
jpeg0 = io_ops.read_file(path)
image0 = image_ops.decode_jpeg(jpeg0)
image1 = image_ops.decode_jpeg(image_ops.encode_jpeg(image0))
- jpeg0, image0, image1 = sess.run([jpeg0, image0, image1])
+ jpeg0, image0, image1 = self.evaluate([jpeg0, image0, image1])
self.assertEqual(len(jpeg0), 3771)
self.assertEqual(image0.shape, (256, 128, 3))
self.assertLess(self.averageError(image0, image1), 1.4)
@@ -3083,7 +3086,7 @@
io_ops.read_file(rgb_path), channels=channels)
cmyk = image_ops.decode_jpeg(
io_ops.read_file(cmyk_path), channels=channels)
- rgb, cmyk = sess.run([rgb, cmyk])
+ rgb, cmyk = self.evaluate([rgb, cmyk])
self.assertEqual(rgb.shape, shape)
self.assertEqual(cmyk.shape, shape)
error = self.averageError(rgb, cmyk)
@@ -3112,7 +3115,7 @@
image2.get_shape().as_list())
# CropAndDecode should be equal to DecodeJpeg+Crop.
- image1_crop, image2 = sess.run([image1_crop, image2])
+ image1_crop, image2 = self.evaluate([image1_crop, image2])
self.assertAllEqual(image1_crop, image2)
def testCropAndDecodeJpegWithInvalidCropWindow(self):
@@ -3131,7 +3134,7 @@
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError,
lambda e: "Invalid JPEG data or crop window" in str(e)):
- sess.run(result)
+ self.evaluate(result)
def testSynthetic(self):
with self.test_session(use_gpu=True) as sess:
@@ -3141,7 +3144,8 @@
image1 = image_ops.decode_jpeg(jpeg0, dct_method="INTEGER_ACCURATE")
image2 = image_ops.decode_jpeg(
image_ops.encode_jpeg(image1), dct_method="INTEGER_ACCURATE")
- jpeg0, image0, image1, image2 = sess.run([jpeg0, image0, image1, image2])
+ jpeg0, image0, image1, image2 = self.evaluate(
+ [jpeg0, image0, image1, image2])
# The decoded-encoded image should be similar to the input
self.assertLess(self.averageError(image0, image1), 0.6)
@@ -3161,7 +3165,8 @@
image1 = image_ops.decode_jpeg(jpeg0, dct_method="INTEGER_FAST")
image2 = image_ops.decode_jpeg(
image_ops.encode_jpeg(image1), dct_method="INTEGER_FAST")
- jpeg0, image0, image1, image2 = sess.run([jpeg0, image0, image1, image2])
+ jpeg0, image0, image1, image2 = self.evaluate(
+ [jpeg0, image0, image1, image2])
# The decoded-encoded image should be similar to the input, but
# note this is worse than the slower algorithm because it is
@@ -3184,7 +3189,7 @@
jpeg0 = image_ops.encode_jpeg(image0)
image1 = image_ops.decode_jpeg(jpeg0, dct_method="INTEGER_FAST")
image2 = image_ops.decode_jpeg(jpeg0)
- image1, image2 = sess.run([image1, image2])
+ image1, image2 = self.evaluate([image1, image2])
# The images should be the same.
self.assertAllClose(image1, image2)
@@ -3230,7 +3235,7 @@
with self.test_session(use_gpu=True) as sess:
png0 = io_ops.read_file(prefix + filename)
image0 = image_ops.decode_png(png0, channels=channels)
- png0, image0 = sess.run([png0, image0])
+ png0, image0 = self.evaluate([png0, image0])
self.assertEqual(image0.shape, (26, 51, channels or channels_in))
if channels == channels_in:
image1 = image_ops.decode_png(image_ops.encode_png(image0))
@@ -3242,7 +3247,7 @@
image0 = constant_op.constant(_SimpleColorRamp())
png0 = image_ops.encode_png(image0, compression=7)
image1 = image_ops.decode_png(png0)
- png0, image0, image1 = sess.run([png0, image0, image1])
+ png0, image0, image1 = self.evaluate([png0, image0, image1])
# PNG is lossless
self.assertAllEqual(image0, image1)
@@ -3257,7 +3262,7 @@
image0 = constant_op.constant(_SimpleColorRamp(), dtype=dtypes.uint16)
png0 = image_ops.encode_png(image0, compression=7)
image1 = image_ops.decode_png(png0, dtype=dtypes.uint16)
- png0, image0, image1 = sess.run([png0, image0, image1])
+ png0, image0, image1 = self.evaluate([png0, image0, image1])
# PNG is lossless
self.assertAllEqual(image0, image1)
@@ -3273,7 +3278,7 @@
image0 = constant_op.constant(gray_alpha)
png0 = image_ops.encode_png(image0, compression=7)
image1 = image_ops.decode_png(png0)
- png0, image0, image1 = sess.run([png0, image0, image1])
+ png0, image0, image1 = self.evaluate([png0, image0, image1])
self.assertEqual(2, image0.shape[-1])
self.assertAllEqual(image0, image1)
@@ -3284,7 +3289,7 @@
image0 = constant_op.constant(gray_alpha, dtype=dtypes.uint16)
png0 = image_ops.encode_png(image0, compression=7)
image1 = image_ops.decode_png(png0, dtype=dtypes.uint16)
- png0, image0, image1 = sess.run([png0, image0, image1])
+ png0, image0, image1 = self.evaluate([png0, image0, image1])
self.assertEqual(2, image0.shape[-1])
self.assertAllEqual(image0, image1)
@@ -3310,7 +3315,7 @@
with self.test_session(use_gpu=True) as sess:
gif0 = io_ops.read_file(prefix + filename)
image0 = image_ops.decode_gif(gif0)
- gif0, image0 = sess.run([gif0, image0])
+ gif0, image0 = self.evaluate([gif0, image0])
self.assertEqual(image0.shape, shape)
@@ -3829,7 +3834,7 @@
"tensorflow/core/lib/psnr/testdata", filename))
im = image_ops.decode_jpeg(content, dct_method="INTEGER_ACCURATE")
im = image_ops.convert_image_dtype(im, dtypes.float32)
- im, = sess.run([im])
+ im, = self.evaluate([im])
return np.expand_dims(im, axis=0)
def _LoadTestImages(self):
@@ -3936,7 +3941,7 @@
"tensorflow/core/lib/ssim/testdata", filename))
im = image_ops.decode_png(content)
im = image_ops.convert_image_dtype(im, dtypes.float32)
- im, = sess.run([im])
+ im, = self.evaluate([im])
return np.expand_dims(im, axis=0)
def _LoadTestImages(self):
@@ -4028,7 +4033,7 @@
"tensorflow/core/lib/ssim/testdata", filename))
im = image_ops.decode_png(content)
im = image_ops.convert_image_dtype(im, dtypes.float32)
- im, = sess.run([im])
+ im, = self.evaluate([im])
return np.expand_dims(im, axis=0)
def _LoadTestImages(self):
@@ -4223,7 +4228,7 @@
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16)
image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
dtypes.uint16)
- image0, image1 = sess.run([image0, image1])
+ image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testPngUint16(self):
@@ -4233,7 +4238,7 @@
image0 = image_ops.decode_image(png0, dtype=dtypes.uint16)
image1 = image_ops.convert_image_dtype(
image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.uint16)
- image0, image1 = sess.run([image0, image1])
+ image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testGifUint16(self):
@@ -4243,7 +4248,7 @@
image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16)
image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
dtypes.uint16)
- image0, image1 = sess.run([image0, image1])
+ image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testBmpUint16(self):
@@ -4253,7 +4258,7 @@
image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16)
image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
dtypes.uint16)
- image0, image1 = sess.run([image0, image1])
+ image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testJpegFloat32(self):
@@ -4263,7 +4268,7 @@
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32)
image1 = image_ops.convert_image_dtype(image_ops.decode_jpeg(jpeg0),
dtypes.float32)
- image0, image1 = sess.run([image0, image1])
+ image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testPngFloat32(self):
@@ -4273,7 +4278,7 @@
image0 = image_ops.decode_image(png0, dtype=dtypes.float32)
image1 = image_ops.convert_image_dtype(
image_ops.decode_png(png0, dtype=dtypes.uint16), dtypes.float32)
- image0, image1 = sess.run([image0, image1])
+ image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testGifFloat32(self):
@@ -4283,7 +4288,7 @@
image0 = image_ops.decode_image(gif0, dtype=dtypes.float32)
image1 = image_ops.convert_image_dtype(image_ops.decode_gif(gif0),
dtypes.float32)
- image0, image1 = sess.run([image0, image1])
+ image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
def testBmpFloat32(self):
@@ -4293,7 +4298,7 @@
image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32)
image1 = image_ops.convert_image_dtype(image_ops.decode_bmp(bmp0),
dtypes.float32)
- image0, image1 = sess.run([image0, image1])
+ image0, image1 = self.evaluate([image0, image1])
self.assertAllEqual(image0, image1)
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 5a1ac67..03d2201 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -216,7 +216,7 @@
dtype = self.dtype
if verify_shape is None:
verify_shape = self._verify_shape
- return constant_op.constant(
+ return constant_op.constant_v1(
self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)
def get_config(self):
@@ -360,8 +360,7 @@
A similar calculation for convolutional networks gives an analogous result
with `dim` equal to the product of the first 3 dimensions. When
nonlinearities are present, we need to multiply this by a constant `factor`.
- See [Sussillo et al., 2014](https://arxiv.org/abs/1412.6558)
- ([pdf](http://arxiv.org/pdf/1412.6558.pdf)) for deeper motivation, experiments
+ See (Sussillo et al., 2014) for deeper motivation, experiments
and the calculation of constants. In section 2.3 there, the constants were
numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15.
@@ -371,6 +370,10 @@
`tf.set_random_seed`
for behavior.
dtype: The data type. Only floating point types are supported.
+
+ References:
+ [Sussillo et al., 2014](https://arxiv.org/abs/1412.6558)
+ ([pdf](http://arxiv.org/pdf/1412.6558.pdf))
"""
@deprecated(None,
@@ -532,6 +535,10 @@
`tf.set_random_seed`
for behavior.
dtype: The data type.
+
+ References:
+ [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C)
+ ([pdf](https://arxiv.org/pdf/1312.6120.pdf))
"""
def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
@@ -576,7 +583,7 @@
The shape of the tensor must have length 3, 4 or 5. The number of input
filters must not exceed the number of output filters. The center pixels of the
tensor form an orthogonal matrix. Other pixels are set to be zero. See
- algorithm 2 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393
+ algorithm 2 in (Xiao et al., 2018).
Args:
@@ -586,6 +593,10 @@
seed: A Python integer. Used to create random seeds. See
`tf.set_random_seed` for behavior.
dtype: The data type.
+
+ References:
+ [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
+ ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
"""
def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
@@ -642,6 +653,10 @@
seed: A Python integer. Used to create random seeds. See
`tf.set_random_seed` for behavior.
dtype: The data type.
+
+ References:
+ [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
+ ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
"""
def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
@@ -698,7 +713,7 @@
filters must not exceed the number of output filters.
The orthogonality(==isometry) is exact when the inputs are circular padded.
There are finite-width effects with non-circular padding (e.g. zero padding).
- See algorithm 1 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393
+ See algorithm 1 in (Xiao et al., 2018).
Args:
gain: Multiplicative factor to apply to the orthogonal
@@ -707,6 +722,10 @@
seed: A Python integer. Used to create random seeds. See
`tf.set_random_seed` for behavior.
dtype: The data type.
+
+ References:
+ [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
+ ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
"""
def __call__(self, shape, dtype=None, partition_info=None):
@@ -834,7 +853,7 @@
filters must not exceed the number of output filters.
The orthogonality(==isometry) is exact when the inputs are circular padded.
There are finite-width effects with non-circular padding (e.g. zero padding).
- See algorithm 1 in [Xiao et al., 2018]: https://arxiv.org/abs/1806.05393
+ See algorithm 1 in (Xiao et al., 2018).
Args:
gain: Multiplicative factor to apply to the orthogonal
@@ -844,6 +863,10 @@
`tf.set_random_seed`
for behavior.
dtype: The data type.
+
+ References:
+ [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
+ ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
"""
def __call__(self, shape, dtype=None, partition_info=None):
@@ -951,7 +974,7 @@
filters must not exceed the number of output filters.
The orthogonality(==isometry) is exact when the inputs are circular padded.
There are finite-width effects with non-circular padding (e.g. zero padding).
- See algorithm 1 [Xiao et al., 2018] in: https://arxiv.org/abs/1806.05393
+ See algorithm 1 (Xiao et al., 2018).
Args:
gain: Multiplicative factor to apply to the orthogonal
@@ -960,6 +983,10 @@
seed: A Python integer. Used to create random seeds. See
`tf.set_random_seed` for behavior.
dtype: The data type.
+
+ References:
+ [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
+ ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
"""
def __call__(self, shape, dtype=None, partition_info=None):
@@ -1139,13 +1166,15 @@
where `fan_in` is the number of input units in the weight tensor
and `fan_out` is the number of output units in the weight tensor.
- Reference: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
-
Args:
seed: A Python integer. Used to create random seeds. See
`tf.set_random_seed`
for behavior.
dtype: The data type. Only floating point types are supported.
+
+ References:
+ [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
+ ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
"""
def __init__(self, seed=None, dtype=dtypes.float32):
@@ -1176,12 +1205,14 @@
where `fan_in` is the number of input units in the weight tensor
and `fan_out` is the number of output units in the weight tensor.
- Reference: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
-
Args:
seed: A Python integer. Used to create random seeds. See
`tf.set_random_seed` for behavior.
dtype: The data type. Only floating point types are supported.
+
+ References:
+ [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
+ ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
"""
def __init__(self, seed=None, dtype=dtypes.float32):
@@ -1233,9 +1264,11 @@
An initializer.
References:
- - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
- - [Efficient
- Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
+ - Self-Normalizing Neural Networks,
+ [Klambauer et al., 2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks)
+ ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
+ - Efficient Backprop,
+ [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
"""
return VarianceScaling(
scale=1., mode="fan_in", distribution="truncated_normal", seed=seed)
@@ -1256,8 +1289,11 @@
An initializer.
References:
- LeCun 98, Efficient Backprop,
- http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf
+ - Self-Normalizing Neural Networks,
+ [Klambauer et al., 2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks)
+ ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
+ - Efficient Backprop,
+ [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
"""
return VarianceScaling(
scale=1., mode="fan_in", distribution="uniform", seed=seed)
@@ -1278,7 +1314,8 @@
An initializer.
References:
- He et al., http://arxiv.org/abs/1502.01852
+ [He et al., 2015](https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html)
+ ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
"""
return VarianceScaling(
scale=2., mode="fan_in", distribution="truncated_normal", seed=seed)
@@ -1299,7 +1336,8 @@
An initializer.
References:
- He et al., http://arxiv.org/abs/1502.01852
+ [He et al., 2015](https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html)
+ ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
"""
return VarianceScaling(
scale=2., mode="fan_in", distribution="uniform", seed=seed)
diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py
index 08d50ce..2c9476a9 100644
--- a/tensorflow/python/ops/linalg/linalg_impl.py
+++ b/tensorflow/python/ops/linalg/linalg_impl.py
@@ -88,7 +88,7 @@
chol = gen_linalg_ops.cholesky(matrix)
return 2.0 * math_ops.reduce_sum(
math_ops.log(math_ops.real(array_ops.matrix_diag_part(chol))),
- reduction_indices=[-1])
+ axis=[-1])
@tf_export('linalg.adjoint')
diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py
index 6fb7a57..8efafda3 100644
--- a/tensorflow/python/ops/linalg/linear_operator.py
+++ b/tensorflow/python/ops/linalg/linear_operator.py
@@ -690,7 +690,7 @@
" Requires conversion to a dense matrix and O(N^3) operations.")
if self._can_use_cholesky():
diag = array_ops.matrix_diag_part(linalg_ops.cholesky(self.to_dense()))
- return 2 * math_ops.reduce_sum(math_ops.log(diag), reduction_indices=[-1])
+ return 2 * math_ops.reduce_sum(math_ops.log(diag), axis=[-1])
_, log_abs_det = linalg.slogdet(self.to_dense())
return log_abs_det
diff --git a/tensorflow/python/ops/linalg/linear_operator_circulant.py b/tensorflow/python/ops/linalg/linear_operator_circulant.py
index 09f0c51..b74baa5 100644
--- a/tensorflow/python/ops/linalg/linear_operator_circulant.py
+++ b/tensorflow/python/ops/linalg/linear_operator_circulant.py
@@ -418,15 +418,13 @@
return math_ops.cast(y, self.dtype)
def _determinant(self):
- reduction_indices = [-(i + 1) for i in range(self.block_depth)]
- det = math_ops.reduce_prod(
- self.spectrum, reduction_indices=reduction_indices)
+ axis = [-(i + 1) for i in range(self.block_depth)]
+ det = math_ops.reduce_prod(self.spectrum, axis=axis)
return math_ops.cast(det, self.dtype)
def _log_abs_determinant(self):
- reduction_indices = [-(i + 1) for i in range(self.block_depth)]
- lad = math_ops.reduce_sum(
- math_ops.log(self._abs_spectrum), reduction_indices=reduction_indices)
+ axis = [-(i + 1) for i in range(self.block_depth)]
+ lad = math_ops.reduce_sum(math_ops.log(self._abs_spectrum), axis=axis)
return math_ops.cast(lad, self.dtype)
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py
index ed53dec..be893c7 100644
--- a/tensorflow/python/ops/linalg/linear_operator_diag.py
+++ b/tensorflow/python/ops/linalg/linear_operator_diag.py
@@ -228,11 +228,11 @@
return diag_mat * x
def _determinant(self):
- return math_ops.reduce_prod(self._diag, reduction_indices=[-1])
+ return math_ops.reduce_prod(self._diag, axis=[-1])
def _log_abs_determinant(self):
log_det = math_ops.reduce_sum(
- math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1])
+ math_ops.log(math_ops.abs(self._diag)), axis=[-1])
if self.dtype.is_complex:
log_det = math_ops.cast(log_det, dtype=self.dtype)
return log_det
diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
index c4288ff..aa0500a 100644
--- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
+++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
@@ -391,7 +391,7 @@
if self._use_cholesky:
chol_cap_diag = array_ops.matrix_diag_part(self._chol_capacitance)
log_abs_det_c = 2 * math_ops.reduce_sum(
- math_ops.log(chol_cap_diag), reduction_indices=[-1])
+ math_ops.log(chol_cap_diag), axis=[-1])
else:
det_c = linalg_ops.matrix_determinant(self._capacitance)
log_abs_det_c = math_ops.log(math_ops.abs(det_c))
diff --git a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
index ca6d3f5..d33fe17 100644
--- a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
+++ b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
@@ -195,11 +195,11 @@
self._tril, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)
def _determinant(self):
- return math_ops.reduce_prod(self._diag, reduction_indices=[-1])
+ return math_ops.reduce_prod(self._diag, axis=[-1])
def _log_abs_determinant(self):
return math_ops.reduce_sum(
- math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1])
+ math_ops.log(math_ops.abs(self._diag)), axis=[-1])
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index bbccc7e..1a9e711 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -423,7 +423,78 @@
# pylint: disable=redefined-builtin
-@tf_export('norm', 'linalg.norm')
+@tf_export('norm', 'linalg.norm', v1=[])
+def norm_v2(tensor,
+ ord='euclidean',
+ axis=None,
+ keepdims=None,
+ name=None):
+ r"""Computes the norm of vectors, matrices, and tensors.
+
+ This function can compute several different vector norms (the 1-norm, the
+ Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
+ matrix norms (Frobenius, 1-norm, 2-norm and inf-norm).
+
+ Args:
+ tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`
+ ord: Order of the norm. Supported values are 'fro', 'euclidean',
+ `1`, `2`, `np.inf` and any positive real number yielding the corresponding
+ p-norm. Default is 'euclidean' which is equivalent to Frobenius norm if
+ `tensor` is a matrix and equivalent to 2-norm for vectors.
+ Some restrictions apply:
+ a) The Frobenius norm `fro` is not defined for vectors,
+ b) If axis is a 2-tuple (matrix norm), only 'euclidean', 'fro', `1`,
+ `2`, `np.inf` are supported.
+ See the description of `axis` on how to compute norms for a batch of
+ vectors or matrices stored in a tensor.
+ axis: If `axis` is `None` (the default), the input is considered a vector
+ and a single vector norm is computed over the entire set of values in the
+ tensor, i.e. `norm(tensor, ord=ord)` is equivalent to
+ `norm(reshape(tensor, [-1]), ord=ord)`.
+ If `axis` is a Python integer, the input is considered a batch of vectors,
+ and `axis` determines the axis in `tensor` over which to compute vector
+ norms.
+ If `axis` is a 2-tuple of Python integers it is considered a batch of
+ matrices and `axis` determines the axes in `tensor` over which to compute
+ a matrix norm.
+ Negative indices are supported. Example: If you are passing a tensor that
+ can be either a matrix or a batch of matrices at runtime, pass
+ `axis=[-2,-1]` instead of `axis=None` to make sure that matrix norms are
+ computed.
+ keepdims: If True, the axis indicated in `axis` are kept with size 1.
+ Otherwise, the dimensions in `axis` are removed from the output shape.
+ name: The name of the op.
+
+ Returns:
+ output: A `Tensor` of the same type as tensor, containing the vector or
+ matrix norms. If `keepdims` is True then the rank of output is equal to
+ the rank of `tensor`. Otherwise, if `axis` is none the output is a scalar,
+ if `axis` is an integer, the rank of `output` is one less than the rank
+ of `tensor`, if `axis` is a 2-tuple the rank of `output` is two less
+ than the rank of `tensor`.
+
+ Raises:
+ ValueError: If `ord` or `axis` is invalid.
+
+ @compatibility(numpy)
+ Mostly equivalent to numpy.linalg.norm.
+ Not supported: ord <= 0, 2-norm for matrices, nuclear norm.
+ Other differences:
+ a) If axis is `None`, treats the flattened `tensor` as a vector
+ regardless of rank.
+ b) Explicitly supports 'euclidean' norm as the default, including for
+ higher order tensors.
+ @end_compatibility
+ """
+ return norm(tensor=tensor,
+ ord=ord,
+ axis=axis,
+ keepdims=keepdims,
+ name=name)
+
+
+# pylint: disable=redefined-builtin
+@tf_export(v1=['norm', 'linalg.norm'])
@deprecation.deprecated_args(
None, 'keep_dims is deprecated, use keepdims instead', 'keep_dims')
def norm(tensor,
diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py
index 5159260..89ff48e 100644
--- a/tensorflow/python/ops/list_ops.py
+++ b/tensorflow/python/ops/list_ops.py
@@ -31,6 +31,8 @@
ops.NotDifferentiable("TensorListConcat")
+ops.NotDifferentiable("TensorListElementShape")
+ops.NotDifferentiable("TensorListLength")
ops.NotDifferentiable("TensorListPushBackBatch")
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index e8cadf9..1b47093 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -33,32 +33,8 @@
from tensorflow.python.util.tf_export import tf_export
-@tf_export("losses.Reduction", v1=[])
-class ReductionV2(object):
- """Types of loss reduction.
-
- Contains the following values:
- `NONE`: Un-reduced weighted losses with the same shape as input.
- `SUM`: Scalar sum of weighted losses.
- `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
- """
-
- NONE = "none"
- SUM = "weighted_sum"
- SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size"
-
- @classmethod
- def all(cls):
- return (cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)
-
- @classmethod
- def validate(cls, key):
- if key not in cls.all():
- raise ValueError("Invalid Reduction Key %s." % key)
-
-
@tf_export(v1=["losses.Reduction"])
-class Reduction(ReductionV2):
+class Reduction(object):
"""Types of loss reduction.
Contains the following values:
@@ -71,6 +47,9 @@
`SUM_BY_NONZERO_WEIGHTS`: Same as `SUM_OVER_NONZERO_WEIGHTS`.
"""
+ NONE = "none"
+ SUM = "weighted_sum"
+ SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size"
MEAN = "weighted_mean"
SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights"
SUM_OVER_NONZERO_WEIGHTS = SUM_BY_NONZERO_WEIGHTS
@@ -154,7 +133,7 @@
return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype)
-@tf_export("losses.compute_weighted_loss")
+@tf_export(v1=["losses.compute_weighted_loss"])
def compute_weighted_loss(
losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@@ -224,7 +203,7 @@
return loss
-@tf_export("losses.absolute_difference")
+@tf_export(v1=["losses.absolute_difference"])
def absolute_difference(
labels, predictions, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@@ -277,7 +256,7 @@
losses, weights, scope, loss_collection, reduction=reduction)
-@tf_export("losses.cosine_distance")
+@tf_export(v1=["losses.cosine_distance"])
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def cosine_distance(
labels, predictions, axis=None, weights=1.0, scope=None,
@@ -333,7 +312,7 @@
losses, weights, scope, loss_collection, reduction=reduction)
-@tf_export("losses.hinge_loss")
+@tf_export(v1=["losses.hinge_loss"])
def hinge_loss(labels, logits, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@@ -383,7 +362,7 @@
losses, weights, scope, loss_collection, reduction=reduction)
-@tf_export("losses.huber_loss")
+@tf_export(v1=["losses.huber_loss"])
def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@@ -461,7 +440,7 @@
losses, weights, scope, loss_collection, reduction=reduction)
-@tf_export("losses.log_loss")
+@tf_export(v1=["losses.log_loss"])
def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
@@ -518,7 +497,7 @@
# TODO(b/37208492): Add reduction arg.
-@tf_export("losses.mean_pairwise_squared_error")
+@tf_export(v1=["losses.mean_pairwise_squared_error"])
def mean_pairwise_squared_error(
labels, predictions, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES):
@@ -583,12 +562,10 @@
diffs = math_ops.subtract(predictions, labels)
- reduction_indices = math_ops.range(1, array_ops.rank(diffs))
+ axis = math_ops.range(1, array_ops.rank(diffs))
sum_squares_diff_per_batch = math_ops.reduce_sum(
- math_ops.square(diffs),
- reduction_indices=reduction_indices,
- keepdims=True)
+ math_ops.square(diffs), axis=axis, keepdims=True)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
term1 = 2.0 * math_ops.div_no_nan(
@@ -596,8 +573,7 @@
math_ops.maximum(num_present_per_batch - 1, 0),
name="value")
- sum_diff = math_ops.reduce_sum(
- diffs, reduction_indices=reduction_indices, keepdims=True)
+ sum_diff = math_ops.reduce_sum(diffs, axis=axis, keepdims=True)
term2 = 2.0 * math_ops.div_no_nan(
math_ops.square(sum_diff),
math_ops.maximum(
@@ -617,7 +593,7 @@
return mean_loss
-@tf_export("losses.mean_squared_error")
+@tf_export(v1=["losses.mean_squared_error"])
def mean_squared_error(
labels, predictions, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@@ -670,7 +646,7 @@
losses, weights, scope, loss_collection, reduction=reduction)
-@tf_export("losses.sigmoid_cross_entropy")
+@tf_export(v1=["losses.sigmoid_cross_entropy"])
def sigmoid_cross_entropy(
multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@@ -734,7 +710,7 @@
losses, weights, scope, loss_collection, reduction=reduction)
-@tf_export("losses.softmax_cross_entropy")
+@tf_export(v1=["losses.softmax_cross_entropy"])
def softmax_cross_entropy(
onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@@ -856,7 +832,7 @@
return labels, predictions, weights
-@tf_export("losses.sparse_softmax_cross_entropy")
+@tf_export(v1=["losses.sparse_softmax_cross_entropy"])
def sparse_softmax_cross_entropy(
labels, logits, weights=1.0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 952a2a1..c3feb18 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -82,8 +82,6 @@
output_type=dtypes.int64):
axis = deprecation.deprecated_argument_lookup(
"axis", axis, "dimension", dimension)
- if axis is None:
- axis = 0
return argmax_v2(input, axis, output_type, name)
@@ -111,6 +109,8 @@
Returns:
A `Tensor` of type `output_type`.
"""
+ if axis is None:
+ axis = 0
return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type)
@@ -127,8 +127,6 @@
output_type=dtypes.int64):
axis = deprecation.deprecated_argument_lookup(
"axis", axis, "dimension", dimension)
- if axis is None:
- axis = 0
return argmin_v2(input, axis, output_type, name)
@@ -156,6 +154,8 @@
Returns:
A `Tensor` of type `output_type`.
"""
+ if axis is None:
+ axis = 0
return gen_math_ops.arg_min(input, axis, name=name, output_type=output_type)
@@ -440,8 +440,8 @@
return gen_math_ops.erf(x, name=name)
-@tf_export("math.scalar_mul", "scalar_mul")
-def scalar_mul(scalar, x):
+@tf_export(v1=["math.scalar_mul", "scalar_mul"])
+def scalar_mul(scalar, x, name=None):
"""Multiplies a scalar times a `Tensor` or `IndexedSlices` object.
Intended for use in gradient code which might deal with `IndexedSlices`
@@ -451,6 +451,7 @@
Args:
scalar: A 0-D scalar `Tensor`. Must have known shape.
x: A `Tensor` or `IndexedSlices` to be scaled.
+ name: A name for the operation (optional).
Returns:
`scalar * x` of the same type (`Tensor` or `IndexedSlices`) as `x`.
@@ -463,13 +464,21 @@
shape = scalar.get_shape()
if shape.ndims == 0:
if isinstance(x, ops.IndexedSlices):
- return ops.IndexedSlices(scalar * x.values, x.indices, x.dense_shape)
+ return ops.IndexedSlices(gen_math_ops.mul(scalar, x.values, name),
+ x.indices, x.dense_shape)
else:
- return scalar * x
+ return gen_math_ops.mul(scalar, x, name)
else:
raise ValueError("Only scalar multiply works, got shape %s" % shape)
+@tf_export("math.scalar_mul", "scalar_mul", v1=[])
+@_set_doc(scalar_mul.__doc__)
+def scalar_mul_v2(scalar, x, name=None):
+ with ops.name_scope(name, "scalar_mul", [x]) as name:
+ return scalar_mul(scalar, x, name)
+
+
@tf_export("math.pow", "pow")
def pow(x, y, name=None): # pylint: disable=redefined-builtin
r"""Computes the power of one value to another.
@@ -1314,7 +1323,7 @@
# Reduction operations
-def _ReductionDims(x, axis, reduction_indices):
+def _ReductionDims(x, axis, reduction_indices=None): # pylint: disable=invalid-name
"""Returns range(0, rank(x)) if reduction_indices is None."""
# TODO(aselle): Remove this after deprecation
if reduction_indices is not None:
@@ -1337,23 +1346,23 @@
return range(0, array_ops.rank(x))
-def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output):
+def _may_reduce_to_scalar(keepdims, axis, output):
"""Set a reduction's output shape to be a scalar if we are certain."""
if not common_shapes.has_fully_defined_shape(output) and (not keepdims) and (
- axis is None) and (reduction_indices is None):
+ axis is None):
output.set_shape(())
return output
-@tf_export("math.reduce_sum", "reduce_sum")
+@tf_export(v1=["math.reduce_sum", "reduce_sum"])
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
-def reduce_sum(input_tensor,
- axis=None,
- keepdims=None,
- name=None,
- reduction_indices=None,
- keep_dims=None):
+def reduce_sum_v1(input_tensor,
+ axis=None,
+ keepdims=None,
+ name=None,
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the sum of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
@@ -1393,18 +1402,58 @@
int64 while tensorflow returns the same dtype as the input.
@end_compatibility
"""
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "reduction_indices", reduction_indices)
keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
"keep_dims", keep_dims)
- if keepdims is None:
- keepdims = False
+ return reduce_sum(input_tensor, axis, keepdims, name)
- return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
- gen_math_ops._sum(
- input_tensor,
- _ReductionDims(input_tensor, axis,
- reduction_indices),
- keepdims,
- name=name))
+
+@tf_export("math.reduce_sum", "reduce_sum", v1=[])
+def reduce_sum(input_tensor, axis=None, keepdims=False, name=None):
+ """Computes the sum of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `axis`.
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `axis` is None, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ For example:
+
+ ```python
+ x = tf.constant([[1, 1, 1], [1, 1, 1]])
+ tf.reduce_sum(x) # 6
+ tf.reduce_sum(x, 0) # [2, 2, 2]
+ tf.reduce_sum(x, 1) # [3, 3]
+ tf.reduce_sum(x, 1, keepdims=True) # [[3], [3]]
+ tf.reduce_sum(x, [0, 1]) # 6
+ ```
+
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+ axis: The dimensions to reduce. If `None` (the default), reduces all
+ dimensions. Must be in the range `[-rank(input_tensor),
+ rank(input_tensor))`.
+ keepdims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor, of the same dtype as the input_tensor.
+
+ @compatibility(numpy)
+ Equivalent to np.sum apart the fact that numpy upcast uint8 and int32 to
+ int64 while tensorflow returns the same dtype as the input.
+ @end_compatibility
+ """
+ keepdims = False if keepdims is None else keepdims
+ return _may_reduce_to_scalar(
+ keepdims, axis,
+ gen_math_ops._sum(
+ input_tensor, _ReductionDims(input_tensor, axis), keepdims,
+ name=name))
@tf_export(v1=["math.count_nonzero", "count_nonzero"])
@@ -1472,8 +1521,6 @@
"axis", axis,
"reduction_indices", reduction_indices
)
- if keepdims is None:
- keepdims = False
return count_nonzero_v2(input_tensor, axis, keepdims, dtype, name)
@@ -1531,6 +1578,8 @@
Returns:
The reduced tensor (number of nonzero values).
"""
+ if keepdims is None:
+ keepdims = False
with ops.name_scope(name, "count_nonzero", [input]):
input = ops.convert_to_tensor(input, name="input")
# A scalar of 'zero' is enough as `not_equal` will broadcast.
@@ -1544,15 +1593,13 @@
dtype=dtype)
-@tf_export("math.reduce_mean", "reduce_mean")
-@deprecation.deprecated_args(
- None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
-def reduce_mean(input_tensor,
- axis=None,
- keepdims=None,
- name=None,
- reduction_indices=None,
- keep_dims=None):
+@tf_export(v1=["math.reduce_mean", "reduce_mean"])
+def reduce_mean_v1(input_tensor,
+ axis=None,
+ keepdims=None,
+ name=None,
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the mean of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
@@ -1602,22 +1649,72 @@
@end_compatibility
"""
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "reduction_indices", reduction_indices)
keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
"keep_dims", keep_dims)
+ return reduce_mean(input_tensor, axis, keepdims, name)
- if keepdims is None:
- keepdims = False
- return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
- gen_math_ops.mean(
- input_tensor,
- _ReductionDims(input_tensor, axis,
- reduction_indices),
- keepdims,
- name=name))
+
+@tf_export("math.reduce_mean", "reduce_mean", v1=[])
+def reduce_mean(input_tensor, axis=None, keepdims=False, name=None):
+ """Computes the mean of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `axis`.
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `axis` is None, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ For example:
+
+ ```python
+ x = tf.constant([[1., 1.], [2., 2.]])
+ tf.reduce_mean(x) # 1.5
+ tf.reduce_mean(x, 0) # [1.5, 1.5]
+ tf.reduce_mean(x, 1) # [1., 2.]
+ ```
+
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+ axis: The dimensions to reduce. If `None` (the default), reduces all
+ dimensions. Must be in the range `[-rank(input_tensor),
+ rank(input_tensor))`.
+ keepdims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+
+ @compatibility(numpy)
+ Equivalent to np.mean
+
+ Please note that `np.mean` has a `dtype` parameter that could be used to
+ specify the output type. By default this is `dtype=float64`. On the other
+ hand, `tf.reduce_mean` has an aggressive type inference from `input_tensor`,
+ for example:
+
+ ```python
+ x = tf.constant([1, 0, 1, 0])
+ tf.reduce_mean(x) # 0
+ y = tf.constant([1., 0., 1., 0.])
+ tf.reduce_mean(y) # 0.5
+ ```
+
+ @end_compatibility
+ """
+ keepdims = False if keepdims is None else keepdims
+ return _may_reduce_to_scalar(
+ keepdims, axis,
+ gen_math_ops.mean(
+ input_tensor, _ReductionDims(input_tensor, axis), keepdims,
+ name=name))
@tf_export("math.reduce_variance")
-def reduce_variance(input_tensor, axis=None, keepdims=None, name=None):
+def reduce_variance(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the variance of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
@@ -1665,7 +1762,7 @@
@tf_export("math.reduce_std")
-def reduce_std(input_tensor, axis=None, keepdims=None, name=None):
+def reduce_std(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the standard deviation of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
@@ -1710,15 +1807,8 @@
return sqrt(variance)
-@tf_export("math.reduce_prod", "reduce_prod")
-@deprecation.deprecated_args(
- None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
-def reduce_prod(input_tensor,
- axis=None,
- keepdims=None,
- name=None,
- reduction_indices=None,
- keep_dims=None):
+@tf_export("math.reduce_prod", "reduce_prod", v1=[])
+def reduce_prod(input_tensor, axis=None, keepdims=False, name=None):
"""Computes the product of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
@@ -1736,6 +1826,48 @@
`[-rank(input_tensor), rank(input_tensor))`.
keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+
+ @compatibility(numpy)
+ Equivalent to np.prod
+ @end_compatibility
+ """
+ keepdims = False if keepdims is None else keepdims
+ return _may_reduce_to_scalar(
+ keepdims, axis,
+ gen_math_ops.prod(
+ input_tensor, _ReductionDims(input_tensor, axis), keepdims,
+ name=name))
+
+
+@tf_export(v1=["math.reduce_prod", "reduce_prod"])
+@deprecation.deprecated_args(
+ None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
+def reduce_prod_v1(input_tensor,
+ axis=None,
+ keepdims=None,
+ name=None,
+ reduction_indices=None,
+ keep_dims=None):
+ """Computes the product of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `axis`.
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `axis` is None, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+ axis: The dimensions to reduce. If `None` (the default), reduces all
+ dimensions. Must be in the range `[-rank(input_tensor),
+ rank(input_tensor))`.
+ keepdims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
keep_dims: Deprecated alias for `keepdims`.
@@ -1746,29 +1878,22 @@
Equivalent to np.prod
@end_compatibility
"""
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "reduction_indices", reduction_indices)
keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
"keep_dims", keep_dims)
-
- if keepdims is None:
- keepdims = False
- return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
- gen_math_ops.prod(
- input_tensor,
- _ReductionDims(input_tensor, axis,
- reduction_indices),
- keepdims,
- name=name))
+ return reduce_prod(input_tensor, axis, keepdims, name)
-@tf_export("math.reduce_min", "reduce_min")
+@tf_export(v1=["math.reduce_min", "reduce_min"])
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
-def reduce_min(input_tensor,
- axis=None,
- keepdims=None,
- name=None,
- reduction_indices=None,
- keep_dims=None):
+def reduce_min_v1(input_tensor,
+ axis=None,
+ keepdims=None,
+ name=None,
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the minimum of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
@@ -1781,9 +1906,9 @@
Args:
input_tensor: The tensor to reduce. Should have real numeric type.
- axis: The dimensions to reduce. If `None` (the default),
- reduces all dimensions. Must be in the range
- `[-rank(input_tensor), rank(input_tensor))`.
+ axis: The dimensions to reduce. If `None` (the default), reduces all
+ dimensions. Must be in the range `[-rank(input_tensor),
+ rank(input_tensor))`.
keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
@@ -1796,28 +1921,57 @@
Equivalent to np.min
@end_compatibility
"""
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "reduction_indices", reduction_indices)
keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
"keep_dims", keep_dims)
- if keepdims is None:
- keepdims = False
- return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
- gen_math_ops._min(
- input_tensor,
- _ReductionDims(input_tensor, axis,
- reduction_indices),
- keepdims,
- name=name))
+ return reduce_min(input_tensor, axis, keepdims, name)
-@tf_export("math.reduce_max", "reduce_max")
+@tf_export("math.reduce_min", "reduce_min", v1=[])
+def reduce_min(input_tensor, axis=None, keepdims=False, name=None):
+ """Computes the minimum of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `axis`.
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `axis` is None, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ Args:
+ input_tensor: The tensor to reduce. Should have real numeric type.
+ axis: The dimensions to reduce. If `None` (the default), reduces all
+ dimensions. Must be in the range `[-rank(input_tensor),
+ rank(input_tensor))`.
+ keepdims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+
+ @compatibility(numpy)
+ Equivalent to np.min
+ @end_compatibility
+ """
+ keepdims = False if keepdims is None else keepdims
+ return _may_reduce_to_scalar(
+ keepdims, axis,
+ gen_math_ops._min(
+ input_tensor, _ReductionDims(input_tensor, axis), keepdims,
+ name=name))
+
+
+@tf_export(v1=["math.reduce_max", "reduce_max"])
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
-def reduce_max(input_tensor,
- axis=None,
- keepdims=None,
- name=None,
- reduction_indices=None,
- keep_dims=None):
+def reduce_max_v1(input_tensor,
+ axis=None,
+ keepdims=None,
+ name=None,
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the maximum of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
@@ -1845,28 +1999,57 @@
Equivalent to np.max
@end_compatibility
"""
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "reduction_indices", reduction_indices)
keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
"keep_dims", keep_dims)
- if keepdims is None:
- keepdims = False
- return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
- gen_math_ops._max(
- input_tensor,
- _ReductionDims(input_tensor, axis,
- reduction_indices),
- keepdims,
- name=name))
+ return reduce_max(input_tensor, axis, keepdims, name)
-@tf_export("math.reduce_all", "reduce_all")
+@tf_export("math.reduce_max", "reduce_max", v1=[])
+def reduce_max(input_tensor, axis=None, keepdims=False, name=None):
+ """Computes the maximum of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `axis`.
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `axis` is None, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ Args:
+ input_tensor: The tensor to reduce. Should have real numeric type.
+ axis: The dimensions to reduce. If `None` (the default), reduces all
+ dimensions. Must be in the range `[-rank(input_tensor),
+ rank(input_tensor))`.
+ keepdims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+
+ @compatibility(numpy)
+ Equivalent to np.max
+ @end_compatibility
+ """
+ keepdims = False if keepdims is None else keepdims
+ return _may_reduce_to_scalar(
+ keepdims, axis,
+ gen_math_ops._max(
+ input_tensor, _ReductionDims(input_tensor, axis), keepdims,
+ name=name))
+
+
+@tf_export(v1=["math.reduce_all", "reduce_all"])
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
-def reduce_all(input_tensor,
- axis=None,
- keepdims=None,
- name=None,
- reduction_indices=None,
- keep_dims=None):
+def reduce_all_v1(input_tensor,
+ axis=None,
+ keepdims=None,
+ name=None,
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the "logical and" of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
@@ -1888,9 +2071,9 @@
Args:
input_tensor: The boolean tensor to reduce.
- axis: The dimensions to reduce. If `None` (the default),
- reduces all dimensions. Must be in the range
- `[-rank(input_tensor), rank(input_tensor))`.
+ axis: The dimensions to reduce. If `None` (the default), reduces all
+ dimensions. Must be in the range `[-rank(input_tensor),
+ rank(input_tensor))`.
keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
@@ -1903,28 +2086,66 @@
Equivalent to np.all
@end_compatibility
"""
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "reduction_indices", reduction_indices)
keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
"keep_dims", keep_dims)
- if keepdims is None:
- keepdims = False
- return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
- gen_math_ops._all(
- input_tensor,
- _ReductionDims(input_tensor, axis,
- reduction_indices),
- keepdims,
- name=name))
+ return reduce_all(input_tensor, axis, keepdims, name)
-@tf_export("math.reduce_any", "reduce_any")
+@tf_export("reduce_all", "math.reduce_all", v1=[])
+def reduce_all(input_tensor, axis=None, keepdims=False, name=None):
+ """Computes the "logical and" of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `axis`.
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `axis` is None, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ For example:
+
+ ```python
+ x = tf.constant([[True, True], [False, False]])
+ tf.reduce_all(x) # False
+ tf.reduce_all(x, 0) # [False, False]
+ tf.reduce_all(x, 1) # [True, False]
+ ```
+
+ Args:
+ input_tensor: The boolean tensor to reduce.
+ axis: The dimensions to reduce. If `None` (the default), reduces all
+ dimensions. Must be in the range `[-rank(input_tensor),
+ rank(input_tensor))`.
+ keepdims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+
+ @compatibility(numpy)
+ Equivalent to np.all
+ @end_compatibility
+ """
+ keepdims = False if keepdims is None else keepdims
+ return _may_reduce_to_scalar(
+ keepdims, axis,
+ gen_math_ops._all(
+ input_tensor, _ReductionDims(input_tensor, axis), keepdims,
+ name=name))
+
+
+@tf_export(v1=["math.reduce_any", "reduce_any"])
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
-def reduce_any(input_tensor,
- axis=None,
- keepdims=None,
- name=None,
- reduction_indices=None,
- keep_dims=None):
+def reduce_any_v1(input_tensor,
+ axis=None,
+ keepdims=None,
+ name=None,
+ reduction_indices=None,
+ keep_dims=None):
"""Computes the "logical or" of elements across dimensions of a tensor.
Reduces `input_tensor` along the dimensions given in `axis`.
@@ -1946,9 +2167,9 @@
Args:
input_tensor: The boolean tensor to reduce.
- axis: The dimensions to reduce. If `None` (the default),
- reduces all dimensions. Must be in the range
- `[-rank(input_tensor), rank(input_tensor))`.
+ axis: The dimensions to reduce. If `None` (the default), reduces all
+ dimensions. Must be in the range `[-rank(input_tensor),
+ rank(input_tensor))`.
keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
@@ -1961,28 +2182,66 @@
Equivalent to np.any
@end_compatibility
"""
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "reduction_indices", reduction_indices)
keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
"keep_dims", keep_dims)
- if keepdims is None:
- keepdims = False
- return _may_reduce_to_scalar(keepdims, axis, reduction_indices,
- gen_math_ops._any(
- input_tensor,
- _ReductionDims(input_tensor, axis,
- reduction_indices),
- keepdims,
- name=name))
+ return reduce_any(input_tensor, axis, keepdims, name)
-@tf_export("math.reduce_logsumexp", "reduce_logsumexp")
+@tf_export("math.reduce_any", "reduce_any", v1=[])
+def reduce_any(input_tensor, axis=None, keepdims=False, name=None):
+ """Computes the "logical or" of elements across dimensions of a tensor.
+
+ Reduces `input_tensor` along the dimensions given in `axis`.
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `axis` is None, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ For example:
+
+ ```python
+ x = tf.constant([[True, True], [False, False]])
+ tf.reduce_any(x) # True
+ tf.reduce_any(x, 0) # [True, True]
+ tf.reduce_any(x, 1) # [True, False]
+ ```
+
+ Args:
+ input_tensor: The boolean tensor to reduce.
+ axis: The dimensions to reduce. If `None` (the default), reduces all
+ dimensions. Must be in the range `[-rank(input_tensor),
+ rank(input_tensor))`.
+ keepdims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+
+ @compatibility(numpy)
+ Equivalent to np.any
+ @end_compatibility
+ """
+ keepdims = False if keepdims is None else keepdims
+ return _may_reduce_to_scalar(
+ keepdims, axis,
+ gen_math_ops._any(
+ input_tensor, _ReductionDims(input_tensor, axis), keepdims,
+ name=name))
+
+
+@tf_export(v1=["math.reduce_logsumexp", "reduce_logsumexp"])
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
-def reduce_logsumexp(input_tensor,
- axis=None,
- keepdims=None,
- name=None,
- reduction_indices=None,
- keep_dims=None):
+def reduce_logsumexp_v1(input_tensor,
+ axis=None,
+ keepdims=None,
+ name=None,
+ reduction_indices=None,
+ keep_dims=None):
"""Computes log(sum(exp(elements across dimensions of a tensor))).
Reduces `input_tensor` along the dimensions given in `axis`.
@@ -2010,9 +2269,9 @@
Args:
input_tensor: The tensor to reduce. Should have numeric type.
- axis: The dimensions to reduce. If `None` (the default),
- reduces all dimensions. Must be in the range
- `[-rank(input_tensor), rank(input_tensor))`.
+ axis: The dimensions to reduce. If `None` (the default), reduces all
+ dimensions. Must be in the range `[-rank(input_tensor),
+ rank(input_tensor))`.
keepdims: If true, retains reduced dimensions with length 1.
name: A name for the operation (optional).
reduction_indices: The old (deprecated) name for axis.
@@ -2021,16 +2280,57 @@
Returns:
The reduced tensor.
"""
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "reduction_indices", reduction_indices)
keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
"keep_dims", keep_dims)
- if keepdims is None:
- keepdims = False
+ return reduce_logsumexp(input_tensor, axis, keepdims, name)
+
+
+@tf_export("math.reduce_logsumexp", "reduce_logsumexp", v1=[])
+def reduce_logsumexp(input_tensor, axis=None, keepdims=False, name=None):
+ """Computes log(sum(exp(elements across dimensions of a tensor))).
+
+ Reduces `input_tensor` along the dimensions given in `axis`.
+ Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
+ entry in `axis`. If `keepdims` is true, the reduced dimensions
+ are retained with length 1.
+
+ If `axis` has no entries, all dimensions are reduced, and a
+ tensor with a single element is returned.
+
+ This function is more numerically stable than log(sum(exp(input))). It avoids
+ overflows caused by taking the exp of large inputs and underflows caused by
+ taking the log of small inputs.
+
+ For example:
+
+ ```python
+ x = tf.constant([[0., 0., 0.], [0., 0., 0.]])
+ tf.reduce_logsumexp(x) # log(6)
+ tf.reduce_logsumexp(x, 0) # [log(2), log(2), log(2)]
+ tf.reduce_logsumexp(x, 1) # [log(3), log(3)]
+ tf.reduce_logsumexp(x, 1, keepdims=True) # [[log(3)], [log(3)]]
+ tf.reduce_logsumexp(x, [0, 1]) # log(6)
+ ```
+
+ Args:
+ input_tensor: The tensor to reduce. Should have numeric type.
+ axis: The dimensions to reduce. If `None` (the default), reduces all
+ dimensions. Must be in the range `[-rank(input_tensor),
+ rank(input_tensor))`.
+ keepdims: If true, retains reduced dimensions with length 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ The reduced tensor.
+ """
+ keepdims = False if keepdims is None else keepdims
input_tensor = ops.convert_to_tensor(input_tensor)
with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name:
raw_max = reduce_max(
input_tensor,
axis=axis,
- reduction_indices=reduction_indices,
keepdims=True)
my_max = array_ops.stop_gradient(
array_ops.where(
@@ -2040,12 +2340,11 @@
reduce_sum(
gen_math_ops.exp(gen_math_ops.sub(input_tensor, my_max)),
axis,
- keepdims=keepdims,
- reduction_indices=reduction_indices))
+ keepdims=keepdims))
if not keepdims:
my_max = array_ops.reshape(my_max, array_ops.shape(result))
result = gen_math_ops.add(result, my_max)
- return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result)
+ return _may_reduce_to_scalar(keepdims, axis, result)
@tf_export("linalg.trace", v1=["linalg.trace", "trace"])
@@ -2637,13 +2936,13 @@
return gen_math_ops.tanh(x, name=name)
-@tf_export("math.bincount", v1=["math.bincount", "bincount"])
-@deprecation.deprecated_endpoints("bincount")
+@tf_export("math.bincount", v1=[])
def bincount(arr,
weights=None,
minlength=None,
maxlength=None,
- dtype=dtypes.int32):
+ dtype=dtypes.int32,
+ name=None):
"""Counts the number of occurrences of each value in an integer array.
If `minlength` and `maxlength` are not given, returns a vector with length
@@ -2655,34 +2954,70 @@
Args:
arr: An int32 tensor of non-negative values.
weights: If non-None, must be the same shape as arr. For each value in
- `arr`, the bin will be incremented by the corresponding weight instead
- of 1.
+ `arr`, the bin will be incremented by the corresponding weight instead of
+ 1.
minlength: If given, ensures the output has length at least `minlength`,
- padding with zeros at the end if necessary.
+ padding with zeros at the end if necessary.
maxlength: If given, skips values in `arr` that are equal or greater than
- `maxlength`, ensuring that the output has length at most `maxlength`.
+ `maxlength`, ensuring that the output has length at most `maxlength`.
+ dtype: If `weights` is None, determines the type of the output bins.
+ name: A name scope for the associated operations (optional).
+
+ Returns:
+ A vector with the same dtype as `weights` or the given `dtype`. The bin
+ values.
+ """
+ name = "bincount" if name is None else name
+ with ops.name_scope(name):
+ arr = ops.convert_to_tensor(arr, name="arr", dtype=dtypes.int32)
+ array_is_nonempty = reduce_prod(array_ops.shape(arr)) > 0
+ output_size = cast(array_is_nonempty, dtypes.int32) * (reduce_max(arr) + 1)
+ if minlength is not None:
+ minlength = ops.convert_to_tensor(
+ minlength, name="minlength", dtype=dtypes.int32)
+ output_size = gen_math_ops.maximum(minlength, output_size)
+ if maxlength is not None:
+ maxlength = ops.convert_to_tensor(
+ maxlength, name="maxlength", dtype=dtypes.int32)
+ output_size = gen_math_ops.minimum(maxlength, output_size)
+ if weights is not None:
+ weights = ops.convert_to_tensor(weights, name="weights")
+ return gen_math_ops.unsorted_segment_sum(weights, arr, output_size)
+ weights = constant_op.constant([], dtype)
+ return gen_math_ops.bincount(arr, output_size, weights)
+
+
+@tf_export(v1=["math.bincount", "bincount"])
+@deprecation.deprecated_endpoints("bincount")
+def bincount_v1(arr,
+ weights=None,
+ minlength=None,
+ maxlength=None,
+ dtype=dtypes.int32):
+ """Counts the number of occurrences of each value in an integer array.
+
+ If `minlength` and `maxlength` are not given, returns a vector with length
+ `tf.reduce_max(arr) + 1` if `arr` is non-empty, and length 0 otherwise.
+ If `weights` are non-None, then index `i` of the output stores the sum of the
+ value in `weights` at each index where the corresponding value in `arr` is
+ `i`.
+
+ Args:
+ arr: An int32 tensor of non-negative values.
+ weights: If non-None, must be the same shape as arr. For each value in
+ `arr`, the bin will be incremented by the corresponding weight instead of
+ 1.
+ minlength: If given, ensures the output has length at least `minlength`,
+ padding with zeros at the end if necessary.
+ maxlength: If given, skips values in `arr` that are equal or greater than
+ `maxlength`, ensuring that the output has length at most `maxlength`.
dtype: If `weights` is None, determines the type of the output bins.
Returns:
A vector with the same dtype as `weights` or the given `dtype`. The bin
values.
"""
- arr = ops.convert_to_tensor(arr, name="arr", dtype=dtypes.int32)
- array_is_nonempty = reduce_prod(array_ops.shape(arr)) > 0
- output_size = cast(array_is_nonempty, dtypes.int32) * (reduce_max(arr) + 1)
- if minlength is not None:
- minlength = ops.convert_to_tensor(
- minlength, name="minlength", dtype=dtypes.int32)
- output_size = gen_math_ops.maximum(minlength, output_size)
- if maxlength is not None:
- maxlength = ops.convert_to_tensor(
- maxlength, name="maxlength", dtype=dtypes.int32)
- output_size = gen_math_ops.minimum(maxlength, output_size)
- if weights is not None:
- weights = ops.convert_to_tensor(weights, name="weights")
- return gen_math_ops.unsorted_segment_sum(weights, arr, output_size)
- weights = constant_op.constant([], dtype)
- return gen_math_ops.bincount(arr, output_size, weights)
+ return bincount(arr, weights, minlength, maxlength, dtype)
@tf_export("math.cumsum", "cumsum")
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index e0329f6..cd45b6f 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -104,7 +104,7 @@
for dtype in [np.float16, np.float32, np.double]:
x_np = np.random.rand(5, 5).astype(dtype)
with self.cached_session(use_gpu=True):
- y_tf = math_ops.reduce_logsumexp(x_np, reduction_indices=[0])
+ y_tf = math_ops.reduce_logsumexp(x_np, axis=[0])
y_np = log(np.sum(exp(x_np), axis=0))
self.assertShapeEqual(y_np, y_tf)
y_tf_np = self.evaluate(y_tf)
@@ -114,7 +114,7 @@
for dtype in [np.float16, np.float32, np.double]:
x_np = np.random.rand(5, 5).astype(dtype)
with self.cached_session(use_gpu=True):
- y_tf = math_ops.reduce_logsumexp(x_np, reduction_indices=0)
+ y_tf = math_ops.reduce_logsumexp(x_np, axis=0)
y_np = log(np.sum(exp(x_np), axis=0))
self.assertShapeEqual(y_np, y_tf)
y_tf_np = self.evaluate(y_tf)
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index 03de8d5..cb42199 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -312,7 +312,7 @@
fn, args=args)
-@tf_export('metrics.mean')
+@tf_export(v1=['metrics.mean'])
def mean(values,
weights=None,
metrics_collections=None,
@@ -393,7 +393,7 @@
return mean_t, update_op
-@tf_export('metrics.accuracy')
+@tf_export(v1=['metrics.accuracy'])
def accuracy(labels,
predictions,
weights=None,
@@ -625,7 +625,7 @@
return _aggregate_across_replicas(collections, f, v)
-@tf_export('metrics.auc')
+@tf_export(v1=['metrics.auc'])
def auc(labels,
predictions,
weights=None,
@@ -830,7 +830,7 @@
return auc_value, update_op
-@tf_export('metrics.mean_absolute_error')
+@tf_export(v1=['metrics.mean_absolute_error'])
def mean_absolute_error(labels,
predictions,
weights=None,
@@ -891,7 +891,7 @@
updates_collections, name or 'mean_absolute_error')
-@tf_export('metrics.mean_cosine_distance')
+@tf_export(v1=['metrics.mean_cosine_distance'])
def mean_cosine_distance(labels,
predictions,
dim,
@@ -948,7 +948,7 @@
predictions=predictions, labels=labels, weights=weights)
radial_diffs = math_ops.multiply(predictions, labels)
radial_diffs = math_ops.reduce_sum(
- radial_diffs, reduction_indices=[
+ radial_diffs, axis=[
dim,
], keepdims=True)
mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
@@ -965,7 +965,7 @@
return mean_distance, update_op
-@tf_export('metrics.mean_per_class_accuracy')
+@tf_export(v1=['metrics.mean_per_class_accuracy'])
def mean_per_class_accuracy(labels,
predictions,
num_classes,
@@ -1069,7 +1069,7 @@
return mean_accuracy_v, update_op
-@tf_export('metrics.mean_iou')
+@tf_export(v1=['metrics.mean_iou'])
def mean_iou(labels,
predictions,
num_classes,
@@ -1170,7 +1170,7 @@
return mean_iou_v, update_op
-@tf_export('metrics.mean_relative_error')
+@tf_export(v1=['metrics.mean_relative_error'])
def mean_relative_error(labels,
predictions,
normalizer,
@@ -1239,7 +1239,7 @@
updates_collections, name or 'mean_relative_error')
-@tf_export('metrics.mean_squared_error')
+@tf_export(v1=['metrics.mean_squared_error'])
def mean_squared_error(labels,
predictions,
weights=None,
@@ -1300,7 +1300,7 @@
name or 'mean_squared_error')
-@tf_export('metrics.mean_tensor')
+@tf_export(v1=['metrics.mean_tensor'])
def mean_tensor(values,
weights=None,
metrics_collections=None,
@@ -1385,7 +1385,7 @@
return mean_t, update_op
-@tf_export('metrics.percentage_below')
+@tf_export(v1=['metrics.percentage_below'])
def percentage_below(values,
threshold,
weights=None,
@@ -1485,7 +1485,7 @@
return value_tensor, update_op
-@tf_export('metrics.false_negatives')
+@tf_export(v1=['metrics.false_negatives'])
def false_negatives(labels,
predictions,
weights=None,
@@ -1537,7 +1537,7 @@
updates_collections)
-@tf_export('metrics.false_negatives_at_thresholds')
+@tf_export(v1=['metrics.false_negatives_at_thresholds'])
def false_negatives_at_thresholds(labels,
predictions,
thresholds,
@@ -1593,7 +1593,7 @@
return fn_value, update_ops['fn']
-@tf_export('metrics.false_positives')
+@tf_export(v1=['metrics.false_positives'])
def false_positives(labels,
predictions,
weights=None,
@@ -1646,7 +1646,7 @@
updates_collections)
-@tf_export('metrics.false_positives_at_thresholds')
+@tf_export(v1=['metrics.false_positives_at_thresholds'])
def false_positives_at_thresholds(labels,
predictions,
thresholds,
@@ -1702,7 +1702,7 @@
return fp_value, update_ops['fp']
-@tf_export('metrics.true_negatives')
+@tf_export(v1=['metrics.true_negatives'])
def true_negatives(labels,
predictions,
weights=None,
@@ -1755,7 +1755,7 @@
updates_collections)
-@tf_export('metrics.true_negatives_at_thresholds')
+@tf_export(v1=['metrics.true_negatives_at_thresholds'])
def true_negatives_at_thresholds(labels,
predictions,
thresholds,
@@ -1811,7 +1811,7 @@
return tn_value, update_ops['tn']
-@tf_export('metrics.true_positives')
+@tf_export(v1=['metrics.true_positives'])
def true_positives(labels,
predictions,
weights=None,
@@ -1864,7 +1864,7 @@
updates_collections)
-@tf_export('metrics.true_positives_at_thresholds')
+@tf_export(v1=['metrics.true_positives_at_thresholds'])
def true_positives_at_thresholds(labels,
predictions,
thresholds,
@@ -1920,7 +1920,7 @@
return tp_value, update_ops['tp']
-@tf_export('metrics.precision')
+@tf_export(v1=['metrics.precision'])
def precision(labels,
predictions,
weights=None,
@@ -2015,7 +2015,7 @@
return p, update_op
-@tf_export('metrics.precision_at_thresholds')
+@tf_export(v1=['metrics.precision_at_thresholds'])
def precision_at_thresholds(labels,
predictions,
thresholds,
@@ -2096,7 +2096,7 @@
return prec, update_op
-@tf_export('metrics.recall')
+@tf_export(v1=['metrics.recall'])
def recall(labels,
predictions,
weights=None,
@@ -2447,7 +2447,7 @@
return var, state_ops.assign_add(var, batch_total_fn, name='update')
-@tf_export('metrics.recall_at_k')
+@tf_export(v1=['metrics.recall_at_k'])
def recall_at_k(labels,
predictions,
k,
@@ -2540,7 +2540,7 @@
name=scope)
-@tf_export('metrics.recall_at_top_k')
+@tf_export(v1=['metrics.recall_at_top_k'])
def recall_at_top_k(labels,
predictions_idx,
k=None,
@@ -2624,7 +2624,7 @@
return metric, update
-@tf_export('metrics.recall_at_thresholds')
+@tf_export(v1=['metrics.recall_at_thresholds'])
def recall_at_thresholds(labels,
predictions,
thresholds,
@@ -2702,7 +2702,7 @@
return rec, update_op
-@tf_export('metrics.root_mean_squared_error')
+@tf_export(v1=['metrics.root_mean_squared_error'])
def root_mean_squared_error(labels,
predictions,
weights=None,
@@ -2773,7 +2773,7 @@
return rmse, update_rmse_op
-@tf_export('metrics.sensitivity_at_specificity')
+@tf_export(v1=['metrics.sensitivity_at_specificity'])
def sensitivity_at_specificity(labels,
predictions,
specificity,
@@ -3045,7 +3045,7 @@
# Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
precision_sum = math_ops.reduce_sum(
- relevant_precision_per_k, reduction_indices=(-1,), name='precision_sum')
+ relevant_precision_per_k, axis=(-1,), name='precision_sum')
# Divide by number of relevant items to get average precision. These are
# the "num_relevant_items" and "AveP" terms from the formula above.
@@ -3146,7 +3146,7 @@
return mean_average_precision, update
-@tf_export('metrics.sparse_average_precision_at_k')
+@tf_export(v1=['metrics.sparse_average_precision_at_k'])
@deprecated(None, 'Use average_precision_at_k instead')
def sparse_average_precision_at_k(labels,
predictions,
@@ -3166,7 +3166,7 @@
name=name)
-@tf_export('metrics.average_precision_at_k')
+@tf_export(v1=['metrics.average_precision_at_k'])
def average_precision_at_k(labels,
predictions,
k,
@@ -3340,7 +3340,7 @@
return var, state_ops.assign_add(var, batch_total_fp, name='update')
-@tf_export('metrics.precision_at_top_k')
+@tf_export(v1=['metrics.precision_at_top_k'])
def precision_at_top_k(labels,
predictions_idx,
k=None,
@@ -3429,7 +3429,7 @@
return metric, update
-@tf_export('metrics.sparse_precision_at_k')
+@tf_export(v1=['metrics.sparse_precision_at_k'])
@deprecated(None, 'Use precision_at_k instead')
def sparse_precision_at_k(labels,
predictions,
@@ -3451,7 +3451,7 @@
name=name)
-@tf_export('metrics.precision_at_k')
+@tf_export(v1=['metrics.precision_at_k'])
def precision_at_k(labels,
predictions,
k,
@@ -3545,7 +3545,7 @@
name=scope)
-@tf_export('metrics.specificity_at_sensitivity')
+@tf_export(v1=['metrics.specificity_at_sensitivity'])
def specificity_at_sensitivity(labels,
predictions,
sensitivity,
diff --git a/tensorflow/python/ops/nccl_ops_test.py b/tensorflow/python/ops/nccl_ops_test.py
index 1b496fe..3b2e2b0 100644
--- a/tensorflow/python/ops/nccl_ops_test.py
+++ b/tensorflow/python/ops/nccl_ops_test.py
@@ -102,7 +102,7 @@
continue
# Test execution and results.
- for t in sess.run(result_tensors):
+ for t in self.evaluate(result_tensors):
self.assertAllClose(t, np_ans)
def _TestGradient(self, nccl_reduce, numpy_fn):
diff --git a/tensorflow/python/ops/nn_batchnorm_test.py b/tensorflow/python/ops/nn_batchnorm_test.py
index b50bccf..31b2790 100644
--- a/tensorflow/python/ops/nn_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_batchnorm_test.py
@@ -235,10 +235,11 @@
odx, odm, odv, odb, odg = gradients_impl.gradients(
[on], [x, m, v, beta, gamma], [backprop])
if scale_after_normalization:
- all_grads = sess.run([dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
+ all_grads = self.evaluate(
+ [dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
to_check = ["dx", "dm", "dv", "db", "dg"]
else:
- all_grads = sess.run([dx, dm, dv, db, odx, odm, odv, odb])
+ all_grads = self.evaluate([dx, dm, dv, db, odx, odm, odv, odb])
to_check = ["dx", "dm", "dv", "db"]
for i, _ in enumerate(to_check):
self.assertAllClose(
@@ -318,7 +319,7 @@
gamma_val, epsilon,
scale_after_normalization,
shift_after_normalization)
- [tf_batch_norm] = sess.run([bn])
+ [tf_batch_norm] = self.evaluate([bn])
self.assertEquals(x_shape, np_batch_norm.shape)
self.assertEquals(x_shape, tf_batch_norm.shape)
self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)
@@ -371,9 +372,9 @@
x.set_shape(x_shape)
op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
if shift:
- tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s])
+ tf_c, tf_m, tf_v, tf_s = self.evaluate([op_c, op_m, op_v, op_s])
else:
- tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v])
+ tf_c, tf_m, tf_v = self.evaluate([op_c, op_m, op_v])
else:
x = array_ops.placeholder(
dtype=dtypes.float32, shape=[None] * len(x_shape), name="x")
@@ -432,7 +433,7 @@
tf_shift_v = None
opm, opv = self._opNormalizeMoments(tf_counts, tf_mean_ss,
tf_variance_ss, tf_shift_v)
- tfm, tfv = sess.run([opm, opv])
+ tfm, tfv = self.evaluate([opm, opv])
self.assertAllClose(npm, tfm, atol=0.000001)
self.assertAllClose(npv, tfv, atol=0.000001)
diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py
index 552b274..4bc33ff 100644
--- a/tensorflow/python/ops/nn_fused_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py
@@ -127,7 +127,7 @@
epsilon=epsilon,
data_format=data_format,
is_training=True)
- y_val, mean_val, var_val = sess.run([y, mean, var])
+ y_val, mean_val, var_val = self.evaluate([y, mean, var])
y_ref, mean_ref, var_ref = self._training_ref(x, scale, offset, epsilon,
data_format)
y_atol = 2e-3 if x_dtype == np.float16 else 1e-3
@@ -277,10 +277,10 @@
if is_training:
epsilon = y.op.get_attr('epsilon')
data_format = y.op.get_attr('data_format')
- grad_vals = sess.run([grad_x, grad_scale, grad_offset])
+ grad_vals = self.evaluate([grad_x, grad_scale, grad_offset])
grad_internal = nn_grad._BatchNormGrad(grad_y, x, scale, pop_mean,
pop_var, epsilon, data_format)
- grad_internal_vals = sess.run(list(grad_internal))
+ grad_internal_vals = self.evaluate(list(grad_internal))
for grad_val, grad_internal_val in zip(grad_vals, grad_internal_vals):
self.assertAllClose(grad_val, grad_internal_val, atol=err_tolerance)
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 902653b..34404ed 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -18,13 +18,13 @@
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
-from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
@@ -948,10 +948,14 @@
grad_grad_x = grad[0]
grad_grad_scale = grad[1]
grad_grad_offset = grad[2]
- grad_x, grad_scale, grad_offset = _BatchNormGrad(
- grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)
- grad_initial = [grad_grad_x, grad_grad_scale, grad_grad_offset]
- grad_grad_y, grad_x, grad_scale = gradients_impl.gradients(
+ with backprop.GradientTape() as tape:
+ tape.watch(grad_y)
+ tape.watch(x)
+ tape.watch(scale)
+ grad_x, grad_scale, grad_offset = _BatchNormGrad(
+ grad_y, x, scale, pop_mean, pop_var, epsilon, data_format, is_training)
+ grad_initial = [grad_grad_x, grad_grad_scale, grad_grad_offset]
+ grad_grad_y, grad_x, grad_scale = tape.gradient(
[grad_x, grad_scale, grad_offset], [grad_y, x, scale], grad_initial)
return grad_grad_y, grad_x, grad_scale, None, None
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 9cf53f1..6591da5 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -329,7 +329,7 @@
return features * math_ops.sigmoid(features)
-@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize")
+@tf_export(v1=["math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize"])
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
"""Normalizes along dimension `axis` using an L2 norm.
@@ -353,8 +353,33 @@
Returns:
A `Tensor` with the same shape as `x`.
"""
+ axis = deprecated_argument_lookup("axis", axis, "dim", dim)
+ return l2_normalize_v2(x, axis, epsilon, name)
+
+
+@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize", v1=[])
+def l2_normalize_v2(x, axis=None, epsilon=1e-12, name=None):
+ """Normalizes along dimension `axis` using an L2 norm.
+
+ For a 1-D tensor with `axis = 0`, computes
+
+ output = x / sqrt(max(sum(x**2), epsilon))
+
+ For `x` with more dimensions, independently normalizes each 1-D slice along
+ dimension `axis`.
+
+ Args:
+ x: A `Tensor`.
+ axis: Dimension along which to normalize. A scalar or a vector of
+ integers.
+ epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
+ divisor if `norm < sqrt(epsilon)`.
+ name: A name for this operation (optional).
+
+ Returns:
+ A `Tensor` with the same shape as `x`.
+ """
with ops.name_scope(name, "l2_normalize", [x]) as name:
- axis = deprecated_argument_lookup("axis", axis, "dim", dim)
x = ops.convert_to_tensor(x, name="x")
square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True)
x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 21008fc..2ffe381 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -207,6 +207,73 @@
name=self.name)
+@tf_export("nn.dilation2d", v1=[])
+def dilation2d_v2(
+ input, # pylint: disable=redefined-builtin
+ filters, # pylint: disable=redefined-builtin
+ strides,
+ padding,
+ data_format,
+ dilations,
+ name=None):
+ """Computes the grayscale dilation of 4-D `input` and 3-D `filters` tensors.
+
+ The `input` tensor has shape `[batch, in_height, in_width, depth]` and the
+ `filters` tensor has shape `[filter_height, filter_width, depth]`, i.e., each
+ input channel is processed independently of the others with its own
+ structuring function. The `output` tensor has shape
+ `[batch, out_height, out_width, depth]`. The spatial dimensions of the output
+ tensor depend on the `padding` algorithm. We currently only support the
+ default "NHWC" `data_format`.
+
+ In detail, the grayscale morphological 2-D dilation is the max-sum correlation
+ (for consistency with `conv2d`, we use unmirrored filters):
+
+ output[b, y, x, c] =
+ max_{dy, dx} input[b,
+ strides[1] * y + rates[1] * dy,
+ strides[2] * x + rates[2] * dx,
+ c] +
+ filters[dy, dx, c]
+
+ Max-pooling is a special case when the filter has size equal to the pooling
+ kernel size and contains all zeros.
+
+ Note on duality: The dilation of `input` by the `filters` is equal to the
+ negation of the erosion of `-input` by the reflected `filters`.
+
+ Args:
+ input: A `Tensor`. Must be one of the following types: `float32`, `float64`,
+ `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`,
+ `uint32`, `uint64`.
+ 4-D with shape `[batch, in_height, in_width, depth]`.
+ filters: A `Tensor`. Must have the same type as `input`.
+ 3-D with shape `[filter_height, filter_width, depth]`.
+ strides: A list of `ints` that has length `>= 4`.
+ The stride of the sliding window for each dimension of the input
+ tensor. Must be: `[1, stride_height, stride_width, 1]`.
+ padding: A `string` from: `"SAME", "VALID"`.
+ The type of padding algorithm to use.
+ data_format: A `string`, only `"NCHW"` is currently supported.
+ dilations: A list of `ints` that has length `>= 4`.
+ The input stride for atrous morphological dilation. Must be:
+ `[1, rate_height, rate_width, 1]`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has the same type as `input`.
+ """
+ if data_format != "NCHW":
+ raise ValueError("Data formats other than NCHW are not yet supported")
+
+ return gen_nn_ops.dilation2d(input=input,
+ filter=filters,
+ strides=strides,
+ rates=dilations,
+ padding=padding,
+ name=name)
+
+
@tf_export("nn.with_space_to_batch")
def with_space_to_batch(
input, # pylint: disable=redefined-builtin
@@ -2118,7 +2185,7 @@
return output
-@tf_export("nn.softmax", "math.softmax")
+@tf_export(v1=["nn.softmax", "math.softmax"])
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def softmax(logits, axis=None, name=None, dim=None):
"""Computes softmax activations.
@@ -2148,7 +2215,34 @@
return _softmax(logits, gen_nn_ops.softmax, axis, name)
-@tf_export("nn.log_softmax", "math.log_softmax")
+@tf_export("nn.softmax", "math.softmax", v1=[])
+def softmax_v2(logits, axis=None, name=None):
+ """Computes softmax activations.
+
+ This function performs the equivalent of
+
+ softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)
+
+ Args:
+ logits: A non-empty `Tensor`. Must be one of the following types: `half`,
+ `float32`, `float64`.
+ axis: The dimension softmax would be performed on. The default is -1 which
+ indicates the last dimension.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has the same type and shape as `logits`.
+
+ Raises:
+ InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
+ dimension of `logits`.
+ """
+ if axis is None:
+ axis = -1
+ return _softmax(logits, gen_nn_ops.softmax, axis, name)
+
+
+@tf_export(v1=["nn.log_softmax", "math.log_softmax"])
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def log_softmax(logits, axis=None, name=None, dim=None):
"""Computes log softmax activations.
@@ -2178,6 +2272,33 @@
return _softmax(logits, gen_nn_ops.log_softmax, axis, name)
+@tf_export("nn.log_softmax", "math.log_softmax", v1=[])
+def log_softmax_v2(logits, axis=None, name=None):
+ """Computes log softmax activations.
+
+ For each batch `i` and class `j` we have
+
+ logsoftmax = logits - log(reduce_sum(exp(logits), axis))
+
+ Args:
+ logits: A non-empty `Tensor`. Must be one of the following types: `half`,
+ `float32`, `float64`.
+ axis: The dimension softmax would be performed on. The default is -1 which
+ indicates the last dimension.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
+
+ Raises:
+ InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
+ dimension of `logits`.
+ """
+ if axis is None:
+ axis = -1
+ return _softmax(logits, gen_nn_ops.log_softmax, axis, name)
+
+
def _ensure_xent_args(name, sentinel, labels, logits):
# Make sure that all arguments were passed as named arguments.
if sentinel is not None:
@@ -2558,6 +2679,67 @@
name=name)
+# pylint: disable=redefined-builtin
+@tf_export("nn.max_pool_with_argmax", v1=[])
+def max_pool_with_argmax_v2(input,
+ ksize,
+ strides,
+ padding,
+ data_format="NHWC",
+ output_dtype=dtypes.int64,
+ name=None):
+ """Performs max pooling on the input and outputs both max values and indices.
+
+ The indices in `argmax` are flattened, so that a maximum value at position
+ `[b, y, x, c]` becomes flattened index
+ `((b * height + y) * width + x) * channels + c`.
+
+ The indices returned are always in `[0, height) x [0, width)` before
+ flattening, even if padding is involved and the mathematically correct answer
+ is outside (either negative or too large). This is a bug, but fixing it is
+ difficult to do in a safe backwards compatible way, especially due to
+ flattening.
+
+ Args:
+ input: A `Tensor`. Must be one of the following types: `float32`, `float64`,
+ `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`,
+ `uint32`, `uint64`.
+ 4-D with shape `[batch, height, width, channels]`. Input to pool over.
+ ksize: A list of `ints` that has length `>= 4`.
+ The size of the window for each dimension of the input tensor.
+ strides: A list of `ints` that has length `>= 4`.
+ The stride of the sliding window for each dimension of the
+ input tensor.
+ padding: A `string` from: `"SAME", "VALID"`.
+ The type of padding algorithm to use.
+ data_format: An optional `string`, must be set to `"NHWC"`. Defaults to
+ `"NHWC"`.
+ Specify the data format of the input and output data.
+ output_dtype: An optional `tf.DType` from: `tf.int32, tf.int64`.
+ Defaults to `tf.int64`.
+ The dtype of the returned argmax tensor.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tuple of `Tensor` objects (output, argmax).
+
+ output: A `Tensor`. Has the same type as `input`.
+ argmax: A `Tensor` of type `output_dtype`.
+ """
+
+ if data_format != "NHWC":
+ raise ValueError("Data formats other than 'NHWC' are not yet supported")
+
+ return gen_nn_ops.max_pool_with_argmax(input=input,
+ ksize=ksize,
+ strides=strides,
+ padding=padding,
+ Targmax=output_dtype,
+ name=name)
+
+# pylint: enable=redefined-builtin
+
+
@ops.RegisterStatistics("Conv2D", "flops")
def _calc_conv_flops(graph, node):
"""Calculates the compute resources needed for Conv2D."""
@@ -2674,7 +2856,7 @@
return noise_shape
-@tf_export("nn.dropout")
+@tf_export(v1=["nn.dropout"])
def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name
"""Computes dropout.
@@ -2733,8 +2915,74 @@
if tensor_util.constant_value(keep_prob) == 1:
return x
+ rate = 1 - keep_prob
+
+ return dropout_v2(x, rate, noise_shape=noise_shape, seed=seed, name=name)
+
+
+@tf_export("nn.dropout", v1=[])
+def dropout_v2(x, rate, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name
+ """Computes dropout.
+
+ With probability `rate`, drops elements of `x`. Input that are kept are
+ scaled up by `1 / (1 - rate)`, otherwise outputs `0`. The scaling is so that
+ the expected sum is unchanged.
+
+ By default, each element is kept or dropped independently. If `noise_shape`
+ is specified, it must be
+ [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+ to the shape of `x`, and only dimensions with `noise_shape[i] == shape(x)[i]`
+ will make independent decisions. For example, if `shape(x) = [k, l, m, n]`
+ and `noise_shape = [k, 1, 1, n]`, each batch and channel component will be
+ kept independently and each row and column will be kept or not kept together.
+
+ Args:
+ x: A floating point tensor.
+ rate: A scalar `Tensor` with the same type as x. The probability
+ that each element is dropped. For example, setting rate=0.1 would drop
+ 10% of input elements.
+ noise_shape: A 1-D `Tensor` of type `int32`, representing the
+ shape for randomly generated keep/drop flags.
+ seed: A Python integer. Used to create random seeds. See
+ `tf.set_random_seed`
+ for behavior.
+ name: A name for this operation (optional).
+
+ Returns:
+ A Tensor of the same shape of `x`.
+
+ Raises:
+ ValueError: If `keep_prob` is not in `(0, 1]` or if `x` is not a floating
+ point tensor.
+ """
+ with ops.name_scope(name, "dropout", [x]) as name:
+ x = ops.convert_to_tensor(x, name="x")
+ if not x.dtype.is_floating:
+ raise ValueError("x has to be a floating point tensor since it's going to"
+ " be scaled. Got a %s tensor instead." % x.dtype)
+ if isinstance(rate, numbers.Real) and not 0 <= rate < 1:
+ raise ValueError("rate must be a scalar tensor or a float in the "
+ "range [0, 1), got %g" % rate)
+
+ # Early return if nothing needs to be dropped.
+ if isinstance(rate, float) and rate == 0:
+ return x
+ if context.executing_eagerly():
+ if isinstance(rate, ops.EagerTensor):
+ if rate.numpy() == 0:
+ return x
+ else:
+ rate = ops.convert_to_tensor(
+ rate, dtype=x.dtype, name="rate")
+ rate.get_shape().assert_is_compatible_with(tensor_shape.scalar())
+
+ # Do nothing if we know rate == 0
+ if tensor_util.constant_value(rate) == 0:
+ return x
+
noise_shape = _get_noise_shape(x, noise_shape)
+ keep_prob = 1 - rate
# uniform [keep_prob, 1.0 + keep_prob)
random_tensor = keep_prob
random_tensor += random_ops.random_uniform(
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 14cc1c6..96b9d6f 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -53,31 +53,29 @@
x_shape = [5, 17]
x_np = np.random.randint(0, 2, size=x_shape).astype(np.float32)
y_np = self._ZeroFraction(x_np)
- with self.cached_session():
- x_tf = constant_op.constant(x_np)
- x_tf.set_shape(x_shape)
- y_tf = nn_impl.zero_fraction(x_tf)
- y_tf_np = self.evaluate(y_tf)
+
+ x_tf = constant_op.constant(x_np)
+ x_tf.set_shape(x_shape)
+ y_tf = nn_impl.zero_fraction(x_tf)
+ y_tf_np = self.evaluate(y_tf)
+
eps = 1e-8
self.assertAllClose(y_tf_np, y_np, eps)
def testZeroFractionEmpty(self):
- with self.cached_session():
- x = np.zeros(0)
- y = nn_impl.zero_fraction(x).eval()
- self.assertTrue(np.isnan(y))
+ x = np.zeros(0)
+ y = self.evaluate(nn_impl.zero_fraction(x))
+ self.assertTrue(np.isnan(y))
def testZeroFraction2_27Zeros(self):
sparsity = nn_impl.zero_fraction(
array_ops.zeros([int(2**27 * 1.01)], dtype=dtypes.int8))
- with self.cached_session():
- self.assertAllClose(1.0, self.evaluate(sparsity))
+ self.assertAllClose(1.0, self.evaluate(sparsity))
def testZeroFraction2_27Ones(self):
sparsity = nn_impl.zero_fraction(
array_ops.ones([int(2**27 * 1.01)], dtype=dtypes.int8))
- with self.cached_session():
- self.assertAllClose(0.0, self.evaluate(sparsity))
+ self.assertAllClose(0.0, self.evaluate(sparsity))
def testUnknownSize(self):
value = array_ops.placeholder(dtype=dtypes.float32)
@@ -302,19 +300,18 @@
y_dim = 30
num_iter = 10
for keep_prob in [0.1, 0.5, 0.8]:
- with self.cached_session():
- t = constant_op.constant(
- 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
- dropout = nn_ops.dropout(t, keep_prob)
- final_count = 0
- self.assertEqual([x_dim, y_dim], dropout.get_shape())
- for _ in xrange(0, num_iter):
- value = self.evaluate(dropout)
- final_count += np.count_nonzero(value)
- # Verifies that there are only two values: 0 and 1/keep_prob.
- sorted_value = np.unique(np.sort(value))
- self.assertEqual(0, sorted_value[0])
- self.assertAllClose(1 / keep_prob, sorted_value[1])
+ t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ dropout = nn_ops.dropout(t, keep_prob)
+ final_count = 0
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ for _ in xrange(0, num_iter):
+ value = self.evaluate(dropout)
+ final_count += np.count_nonzero(value)
+ # Verifies that there are only two values: 0 and 1/keep_prob.
+ sorted_value = np.unique(np.sort(value))
+ self.assertEqual(0, sorted_value[0])
+ self.assertAllClose(1 / keep_prob, sorted_value[1])
+
# Check that we are in the 15% error range
expected_count = x_dim * y_dim * keep_prob * num_iter
rel_error = math.fabs(final_count - expected_count) / expected_count
@@ -330,19 +327,18 @@
y_dim = 3
num_iter = 10
for keep_prob in [0.1, 0.5, 0.8]:
- with self.cached_session():
- t = constant_op.constant(
- 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
- dropout = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1])
- self.assertEqual([x_dim, y_dim], dropout.get_shape())
- final_count = 0
- for _ in xrange(0, num_iter):
- value = self.evaluate(dropout)
- final_count += np.count_nonzero(value)
- # Verifies that there are only two values: 0 and 1/keep_prob.
- sorted_value = np.unique(np.sort(value))
- self.assertEqual(0, sorted_value[0])
- self.assertAllClose(1 / keep_prob, sorted_value[1])
+ t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ dropout = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ final_count = 0
+ for _ in xrange(0, num_iter):
+ value = self.evaluate(dropout)
+ final_count += np.count_nonzero(value)
+ # Verifies that there are only two values: 0 and 1/keep_prob.
+ sorted_value = np.unique(np.sort(value))
+ self.assertEqual(0, sorted_value[0])
+ self.assertAllClose(1 / keep_prob, sorted_value[1])
+
# Check that we are in the 15% error range
expected_count = x_dim * y_dim * keep_prob * num_iter
rel_error = math.fabs(final_count - expected_count) / expected_count
@@ -355,17 +351,15 @@
y_dim = 30
num_iter = 10
for keep_prob in [0.1, 0.5, 0.8]:
- with self.cached_session():
- t = constant_op.constant(
- 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
- dropout = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1])
- self.assertEqual([x_dim, y_dim], dropout.get_shape())
- for _ in xrange(0, num_iter):
- value = self.evaluate(dropout)
- # Verifies that each y column as only one type of activation.
- for i in xrange(x_dim):
- sorted_value = np.unique(np.sort(value[i, :]))
- self.assertEqual(sorted_value.size, 1)
+ t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ dropout = nn_ops.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ for _ in xrange(0, num_iter):
+ value = self.evaluate(dropout)
+ # Verifies that each y column as only one type of activation.
+ for i in xrange(x_dim):
+ sorted_value = np.unique(np.sort(value[i, :]))
+ self.assertEqual(sorted_value.size, 1)
def testDropoutPlaceholderKeepProb(self):
# Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate
@@ -409,20 +403,19 @@
y_dim = 3
num_iter = 10
for keep_prob in [0.1, 0.5, 0.8]:
- with self.cached_session():
- t = constant_op.constant(
- 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
- # Set noise_shape=[None, 1] which means [x_dim, 1].
- dropout = nn_ops.dropout(t, keep_prob, noise_shape=[None, 1])
- self.assertEqual([x_dim, y_dim], dropout.get_shape())
- final_count = 0
- for _ in xrange(0, num_iter):
- value = self.evaluate(dropout)
- final_count += np.count_nonzero(value)
- # Verifies that there are only two values: 0 and 1/keep_prob.
- sorted_value = np.unique(np.sort(value))
- self.assertEqual(0, sorted_value[0])
- self.assertAllClose(1 / keep_prob, sorted_value[1])
+ t = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=dtypes.float32)
+ # Set noise_shape=[None, 1] which means [x_dim, 1].
+ dropout = nn_ops.dropout(t, keep_prob, noise_shape=[None, 1])
+ self.assertEqual([x_dim, y_dim], dropout.get_shape())
+ final_count = 0
+ for _ in xrange(0, num_iter):
+ value = self.evaluate(dropout)
+ final_count += np.count_nonzero(value)
+ # Verifies that there are only two values: 0 and 1/keep_prob.
+ sorted_value = np.unique(np.sort(value))
+ self.assertEqual(0, sorted_value[0])
+ self.assertAllClose(1 / keep_prob, sorted_value[1])
+
# Check that we are in the 15% error range
expected_count = x_dim * y_dim * keep_prob * num_iter
rel_error = math.fabs(final_count - expected_count) / expected_count
@@ -563,78 +556,78 @@
initializer=constant_op.constant(biases))
with self.session(graph=g) as sess:
variables.global_variables_initializer().run()
- return sess.run([list(sharded_weights), list(sharded_biases)])
+ return self.evaluate([list(sharded_weights), list(sharded_biases)])
def testShapes(self):
np.random.seed(0)
num_classes = 5
batch_size = 3
- with self.cached_session() as sess:
- for num_true in range(1, 5):
- labels = np.random.randint(
- low=0, high=num_classes, size=batch_size * num_true)
- (weights, biases, hidden_acts, sampled_vals, exp_logits,
- exp_labels) = self._GenerateTestData(
- num_classes=num_classes,
- dim=10,
- batch_size=batch_size,
- num_true=num_true,
- labels=labels,
- sampled=[1, 0, 2, 3],
- subtract_log_q=False)
- logits_tensor, labels_tensor = _compute_sampled_logits(
- weights=constant_op.constant(weights),
- biases=constant_op.constant(biases),
- labels=constant_op.constant(
- labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
- inputs=constant_op.constant(hidden_acts),
- num_sampled=4,
- num_classes=num_classes,
- num_true=num_true,
- sampled_values=sampled_vals,
- subtract_log_q=False,
- remove_accidental_hits=False,
- partition_strategy="div",
- name="sampled_logits_basic_num_true_%d" % num_true)
- got_logits, got_labels = sess.run([logits_tensor, labels_tensor])
- self.assertEqual(exp_logits.shape, got_logits.shape, self._eps)
- self.assertEqual(exp_labels.shape, got_labels.shape, self._eps)
+
+ for num_true in range(1, 5):
+ labels = np.random.randint(
+ low=0, high=num_classes, size=batch_size * num_true)
+ (weights, biases, hidden_acts, sampled_vals, exp_logits,
+ exp_labels) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=num_true,
+ labels=labels,
+ sampled=[1, 0, 2, 3],
+ subtract_log_q=False)
+ logits_tensor, labels_tensor = _compute_sampled_logits(
+ weights=constant_op.constant(weights),
+ biases=constant_op.constant(biases),
+ labels=constant_op.constant(
+ labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=num_true,
+ sampled_values=sampled_vals,
+ subtract_log_q=False,
+ remove_accidental_hits=False,
+ partition_strategy="div",
+ name="sampled_logits_basic_num_true_%d" % num_true)
+ got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor])
+ self.assertEqual(exp_logits.shape, got_logits.shape, self._eps)
+ self.assertEqual(exp_labels.shape, got_labels.shape, self._eps)
def testBasic(self):
"""Without accidental hit removal or subtract_log_q."""
np.random.seed(0)
num_classes = 5
batch_size = 3
- with self.cached_session() as sess:
- for num_true in range(1, 5):
- labels = np.random.randint(
- low=0, high=num_classes, size=batch_size * num_true)
- (weights, biases, hidden_acts, sampled_vals, exp_logits,
- exp_labels) = self._GenerateTestData(
- num_classes=num_classes,
- dim=10,
- batch_size=batch_size,
- num_true=num_true,
- labels=labels,
- sampled=[1, 0, 2, 3],
- subtract_log_q=False)
- logits_tensor, labels_tensor = _compute_sampled_logits(
- weights=constant_op.constant(weights),
- biases=constant_op.constant(biases),
- labels=constant_op.constant(
- labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
- inputs=constant_op.constant(hidden_acts),
- num_sampled=4,
- num_classes=num_classes,
- num_true=num_true,
- sampled_values=sampled_vals,
- subtract_log_q=False,
- remove_accidental_hits=False,
- partition_strategy="div",
- name="sampled_logits_basic_num_true_%d" % num_true)
- got_logits, got_labels = sess.run([logits_tensor, labels_tensor])
- self.assertAllClose(exp_logits, got_logits, self._eps)
- self.assertAllClose(exp_labels, got_labels, self._eps)
+
+ for num_true in range(1, 5):
+ labels = np.random.randint(
+ low=0, high=num_classes, size=batch_size * num_true)
+ (weights, biases, hidden_acts, sampled_vals, exp_logits,
+ exp_labels) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=num_true,
+ labels=labels,
+ sampled=[1, 0, 2, 3],
+ subtract_log_q=False)
+ logits_tensor, labels_tensor = _compute_sampled_logits(
+ weights=constant_op.constant(weights),
+ biases=constant_op.constant(biases),
+ labels=constant_op.constant(
+ labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=num_true,
+ sampled_values=sampled_vals,
+ subtract_log_q=False,
+ remove_accidental_hits=False,
+ partition_strategy="div",
+ name="sampled_logits_basic_num_true_%d" % num_true)
+ got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor])
+ self.assertAllClose(exp_logits, got_logits, self._eps)
+ self.assertAllClose(exp_labels, got_labels, self._eps)
def testAccidentalHitRemoval(self):
"""With accidental hit removal, no subtract_log_q."""
@@ -642,118 +635,118 @@
num_classes = 5
batch_size = 3
sampled = [1, 0, 2, 3]
- with self.cached_session():
- for num_true in range(1, 5):
- labels = np.random.randint(
- low=0, high=num_classes, size=batch_size * num_true)
- (weights, biases, hidden_acts, sampled_vals, _,
- _) = self._GenerateTestData(
- num_classes=num_classes,
- dim=10,
- batch_size=batch_size,
- num_true=num_true,
- labels=labels,
- sampled=sampled,
- subtract_log_q=False)
- logits_tensor, _ = _compute_sampled_logits(
- weights=constant_op.constant(weights),
- biases=constant_op.constant(biases),
- labels=constant_op.constant(
- labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
- inputs=constant_op.constant(hidden_acts),
- num_sampled=len(sampled),
- num_classes=num_classes,
- num_true=num_true,
- sampled_values=sampled_vals,
- subtract_log_q=False,
- remove_accidental_hits=True,
- partition_strategy="div",
- name="sampled_logits_accidental_hit_removal_num_true_%d" % num_true)
- # Test that the exponentiated logits of accidental hits are near 0.
- # First we need to find the hits in this random test run:
- labels_reshape = labels.reshape((batch_size, num_true))
- got_logits = self.evaluate(logits_tensor)
- for row in xrange(batch_size):
- row_labels = labels_reshape[row, :]
- for col in xrange(len(sampled)):
- if sampled[col] in row_labels:
- # We need to add the num_true_test offset into logits_*
- self.assertNear(
- np.exp(got_logits[row, col + num_true]), 0., self._eps)
+
+ for num_true in range(1, 5):
+ labels = np.random.randint(
+ low=0, high=num_classes, size=batch_size * num_true)
+ (weights, biases, hidden_acts, sampled_vals, _,
+ _) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=num_true,
+ labels=labels,
+ sampled=sampled,
+ subtract_log_q=False)
+ logits_tensor, _ = _compute_sampled_logits(
+ weights=constant_op.constant(weights),
+ biases=constant_op.constant(biases),
+ labels=constant_op.constant(
+ labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=len(sampled),
+ num_classes=num_classes,
+ num_true=num_true,
+ sampled_values=sampled_vals,
+ subtract_log_q=False,
+ remove_accidental_hits=True,
+ partition_strategy="div",
+ name="sampled_logits_accidental_hit_removal_num_true_%d" % num_true)
+ # Test that the exponentiated logits of accidental hits are near 0.
+ # First we need to find the hits in this random test run:
+ labels_reshape = labels.reshape((batch_size, num_true))
+ got_logits = self.evaluate(logits_tensor)
+ for row in xrange(batch_size):
+ row_labels = labels_reshape[row, :]
+ for col in xrange(len(sampled)):
+ if sampled[col] in row_labels:
+ # We need to add the num_true_test offset into logits_*
+ self.assertNear(
+ np.exp(got_logits[row, col + num_true]), 0., self._eps)
def testSubtractLogQ(self):
"""With subtract_log_q, no accidental hit removal."""
np.random.seed(0)
num_classes = 5
batch_size = 3
- with self.cached_session() as sess:
- for num_true in range(1, 5):
- labels = np.random.randint(
- low=0, high=num_classes, size=batch_size * num_true)
- (weights, biases, hidden_acts, sampled_vals, exp_logits,
- exp_labels) = self._GenerateTestData(
- num_classes=num_classes,
- dim=10,
- batch_size=batch_size,
- num_true=num_true,
- labels=labels,
- sampled=[1, 0, 2, 3],
- subtract_log_q=True)
- logits_tensor, labels_tensor = _compute_sampled_logits(
- weights=constant_op.constant(weights),
- biases=constant_op.constant(biases),
- labels=constant_op.constant(
- labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
- inputs=constant_op.constant(hidden_acts),
- num_sampled=4,
- num_classes=num_classes,
- num_true=num_true,
- sampled_values=sampled_vals,
- subtract_log_q=True,
- remove_accidental_hits=False,
- partition_strategy="div",
- name="sampled_logits_subtract_log_q_num_true_%d" % num_true)
- got_logits, got_labels = sess.run([logits_tensor, labels_tensor])
- self.assertAllClose(exp_logits, got_logits, self._eps)
- self.assertAllClose(exp_labels, got_labels, self._eps)
+
+ for num_true in range(1, 5):
+ labels = np.random.randint(
+ low=0, high=num_classes, size=batch_size * num_true)
+ (weights, biases, hidden_acts, sampled_vals, exp_logits,
+ exp_labels) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=num_true,
+ labels=labels,
+ sampled=[1, 0, 2, 3],
+ subtract_log_q=True)
+ logits_tensor, labels_tensor = _compute_sampled_logits(
+ weights=constant_op.constant(weights),
+ biases=constant_op.constant(biases),
+ labels=constant_op.constant(
+ labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=num_true,
+ sampled_values=sampled_vals,
+ subtract_log_q=True,
+ remove_accidental_hits=False,
+ partition_strategy="div",
+ name="sampled_logits_subtract_log_q_num_true_%d" % num_true)
+ got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor])
+ self.assertAllClose(exp_logits, got_logits, self._eps)
+ self.assertAllClose(exp_labels, got_labels, self._eps)
def testSharded(self):
"""With sharded weights and sharded biases."""
np.random.seed(0)
num_classes = 5
batch_size = 3
- with self.cached_session() as sess:
- for num_true in range(1, 5):
- labels = np.random.randint(
- low=0, high=num_classes, size=batch_size * num_true)
- (weights, biases, hidden_acts, sampled_vals, exp_logits,
- exp_labels) = self._GenerateTestData(
- num_classes=num_classes,
- dim=10,
- batch_size=batch_size,
- num_true=num_true,
- labels=labels,
- sampled=[1, 0, 2, 3],
- subtract_log_q=False)
- weight_shards, bias_shards = self._ShardTestEmbeddings(
- weights, biases, num_shards=3)
- logits_tensor, labels_tensor = _compute_sampled_logits(
- weights=[constant_op.constant(shard) for shard in weight_shards],
- biases=[constant_op.constant(shard) for shard in bias_shards],
- labels=constant_op.constant(
- labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
- inputs=constant_op.constant(hidden_acts),
- num_sampled=4,
- num_classes=num_classes,
- num_true=num_true,
- sampled_values=sampled_vals,
- subtract_log_q=False,
- remove_accidental_hits=False,
- partition_strategy="div",
- name="sampled_logits_sharded_num_true_%d" % num_true)
- got_logits, got_labels = sess.run([logits_tensor, labels_tensor])
- self.assertAllClose(exp_logits, got_logits, self._eps)
- self.assertAllClose(exp_labels, got_labels, self._eps)
+
+ for num_true in range(1, 5):
+ labels = np.random.randint(
+ low=0, high=num_classes, size=batch_size * num_true)
+ (weights, biases, hidden_acts, sampled_vals, exp_logits,
+ exp_labels) = self._GenerateTestData(
+ num_classes=num_classes,
+ dim=10,
+ batch_size=batch_size,
+ num_true=num_true,
+ labels=labels,
+ sampled=[1, 0, 2, 3],
+ subtract_log_q=False)
+ weight_shards, bias_shards = self._ShardTestEmbeddings(
+ weights, biases, num_shards=3)
+ logits_tensor, labels_tensor = _compute_sampled_logits(
+ weights=[constant_op.constant(shard) for shard in weight_shards],
+ biases=[constant_op.constant(shard) for shard in bias_shards],
+ labels=constant_op.constant(
+ labels, dtype=dtypes.int64, shape=(batch_size, num_true)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=num_true,
+ sampled_values=sampled_vals,
+ subtract_log_q=False,
+ remove_accidental_hits=False,
+ partition_strategy="div",
+ name="sampled_logits_sharded_num_true_%d" % num_true)
+ got_logits, got_labels = self.evaluate([logits_tensor, labels_tensor])
+ self.assertAllClose(exp_logits, got_logits, self._eps)
+ self.assertAllClose(exp_labels, got_labels, self._eps)
def testNCELoss(self):
# A simple test to verify the numerics.
@@ -782,35 +775,34 @@
exp_nce_loss = np.sum(
_SigmoidCrossEntropyWithLogits(exp_logits, exp_labels), 1)
- with self.cached_session():
- got_nce_loss = nn_impl.nce_loss(
- weights=constant_op.constant(weights),
- biases=constant_op.constant(biases),
- labels=constant_op.constant(labels, shape=(batch_size, 1)),
- inputs=constant_op.constant(hidden_acts),
- num_sampled=4,
- num_classes=num_classes,
- num_true=1,
- sampled_values=sampled_vals,
- partition_strategy="div")
+ got_nce_loss = nn_impl.nce_loss(
+ weights=constant_op.constant(weights),
+ biases=constant_op.constant(biases),
+ labels=constant_op.constant(labels, shape=(batch_size, 1)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=1,
+ sampled_values=sampled_vals,
+ partition_strategy="div")
- self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4)
+ self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4)
- # Test with sharded weights and sharded biases.
- weight_shards, bias_shards = self._ShardTestEmbeddings(
- weights, biases, num_shards=3)
- got_nce_loss = nn_impl.nce_loss(
- weights=[constant_op.constant(shard) for shard in weight_shards],
- biases=[constant_op.constant(shard) for shard in bias_shards],
- labels=constant_op.constant(labels, shape=(batch_size, 1)),
- inputs=constant_op.constant(hidden_acts),
- num_sampled=4,
- num_classes=num_classes,
- num_true=1,
- sampled_values=sampled_vals,
- partition_strategy="div")
+ # Test with sharded weights and sharded biases.
+ weight_shards, bias_shards = self._ShardTestEmbeddings(
+ weights, biases, num_shards=3)
+ got_nce_loss = nn_impl.nce_loss(
+ weights=[constant_op.constant(shard) for shard in weight_shards],
+ biases=[constant_op.constant(shard) for shard in bias_shards],
+ labels=constant_op.constant(labels, shape=(batch_size, 1)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=1,
+ sampled_values=sampled_vals,
+ partition_strategy="div")
- self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4)
+ self.assertAllClose(exp_nce_loss, self.evaluate(got_nce_loss), 1e-4)
def testSampledSoftmaxLoss(self):
# A simple test to verify the numerics.
@@ -839,39 +831,38 @@
exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits(
exp_logits, exp_labels)
- with self.cached_session():
- got_sampled_softmax_loss = nn_impl.sampled_softmax_loss(
- weights=constant_op.constant(weights),
- biases=constant_op.constant(biases),
- labels=constant_op.constant(labels, shape=(batch_size, 1)),
- inputs=constant_op.constant(hidden_acts),
- num_sampled=4,
- num_classes=num_classes,
- num_true=1,
- sampled_values=sampled_vals,
- remove_accidental_hits=False,
- partition_strategy="div")
+ got_sampled_softmax_loss = nn_impl.sampled_softmax_loss(
+ weights=constant_op.constant(weights),
+ biases=constant_op.constant(biases),
+ labels=constant_op.constant(labels, shape=(batch_size, 1)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=1,
+ sampled_values=sampled_vals,
+ remove_accidental_hits=False,
+ partition_strategy="div")
- self.assertAllClose(exp_sampled_softmax_loss,
- self.evaluate(got_sampled_softmax_loss), 1e-4)
+ self.assertAllClose(exp_sampled_softmax_loss,
+ self.evaluate(got_sampled_softmax_loss), 1e-4)
- # Test with sharded weights and sharded biases.
- weight_shards, bias_shards = self._ShardTestEmbeddings(
- weights, biases, num_shards=3)
- got_sampled_softmax_loss = nn_impl.sampled_softmax_loss(
- weights=[constant_op.constant(shard) for shard in weight_shards],
- biases=[constant_op.constant(shard) for shard in bias_shards],
- labels=constant_op.constant(labels, shape=(batch_size, 1)),
- inputs=constant_op.constant(hidden_acts),
- num_sampled=4,
- num_classes=num_classes,
- num_true=1,
- sampled_values=sampled_vals,
- remove_accidental_hits=False,
- partition_strategy="div")
+ # Test with sharded weights and sharded biases.
+ weight_shards, bias_shards = self._ShardTestEmbeddings(
+ weights, biases, num_shards=3)
+ got_sampled_softmax_loss = nn_impl.sampled_softmax_loss(
+ weights=[constant_op.constant(shard) for shard in weight_shards],
+ biases=[constant_op.constant(shard) for shard in bias_shards],
+ labels=constant_op.constant(labels, shape=(batch_size, 1)),
+ inputs=constant_op.constant(hidden_acts),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=1,
+ sampled_values=sampled_vals,
+ remove_accidental_hits=False,
+ partition_strategy="div")
- self.assertAllClose(exp_sampled_softmax_loss,
- self.evaluate(got_sampled_softmax_loss), 1e-4)
+ self.assertAllClose(exp_sampled_softmax_loss,
+ self.evaluate(got_sampled_softmax_loss), 1e-4)
def testSampledSoftmaxLossBf16(self):
# A simple test to verify the numerics for bfloat16.
@@ -900,29 +891,30 @@
exp_sampled_softmax_loss = _SoftmaxCrossEntropyWithLogits(
exp_logits, exp_labels)
- with self.cached_session():
- true_exp_bf16 = np.full(
- [batch_size, 1], fill_value=0.5, dtype=dtypes.bfloat16.as_numpy_dtype)
- sampled_exp_bf16 = np.full(
- [len(sampled)], fill_value=0.5, dtype=dtypes.bfloat16.as_numpy_dtype)
- sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16)
+ true_exp_bf16 = np.full([batch_size, 1],
+ fill_value=0.5,
+ dtype=dtypes.bfloat16.as_numpy_dtype)
+ sampled_exp_bf16 = np.full([len(sampled)],
+ fill_value=0.5,
+ dtype=dtypes.bfloat16.as_numpy_dtype)
+ sampled_vals_bf16 = (sampled, true_exp_bf16, sampled_exp_bf16)
- got_sampled_softmax_loss = math_ops.cast(
- nn_impl.sampled_softmax_loss(
- weights=constant_op.constant(weights, dtype=dtypes.bfloat16),
- biases=constant_op.constant(biases, dtype=dtypes.bfloat16),
- labels=constant_op.constant(
- labels, shape=(batch_size, 1), dtype=dtypes.bfloat16),
- inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16),
- num_sampled=4,
- num_classes=num_classes,
- num_true=1,
- sampled_values=sampled_vals_bf16,
- remove_accidental_hits=False,
- partition_strategy="div"), dtypes.float32)
+ got_sampled_softmax_loss = math_ops.cast(
+ nn_impl.sampled_softmax_loss(
+ weights=constant_op.constant(weights, dtype=dtypes.bfloat16),
+ biases=constant_op.constant(biases, dtype=dtypes.bfloat16),
+ labels=constant_op.constant(
+ labels, shape=(batch_size, 1), dtype=dtypes.bfloat16),
+ inputs=constant_op.constant(hidden_acts, dtype=dtypes.bfloat16),
+ num_sampled=4,
+ num_classes=num_classes,
+ num_true=1,
+ sampled_values=sampled_vals_bf16,
+ remove_accidental_hits=False,
+ partition_strategy="div"), dtypes.float32)
- self.assertAllClose(exp_sampled_softmax_loss,
- self.evaluate(got_sampled_softmax_loss), 1e-1)
+ self.assertAllClose(exp_sampled_softmax_loss,
+ self.evaluate(got_sampled_softmax_loss), 1e-1)
class CReluTest(test_lib.TestCase):
@@ -931,9 +923,9 @@
np.random.seed(1) # Make it reproducible.
x = np.random.randn(3, 4).astype(np.float32)
y = np.concatenate([x * (x > 0), -x * (x < 0)], axis=1)
- with self.cached_session():
- z = nn_ops.crelu(constant_op.constant(x)).eval()
- self.assertAllClose(y, z, 1e-4)
+
+ z = self.evaluate(nn_ops.crelu(constant_op.constant(x)))
+ self.assertAllClose(y, z, 1e-4)
class ReluTest(test_lib.TestCase):
@@ -942,9 +934,9 @@
np.random.seed(1) # Make it reproducible.
x = np.random.randn(3, 4).astype(np.float32)
y = np.maximum(x, 0.0)
- with self.cached_session():
- z = nn_ops.relu(constant_op.constant(x)).eval()
- self.assertAllEqual(y, z)
+
+ z = self.evaluate(nn_ops.relu(constant_op.constant(x)))
+ self.assertAllEqual(y, z)
def testNaNs(self):
# Test that relu(nan) = nan for various sizes.
@@ -967,8 +959,9 @@
outputs = nn_ops.leaky_relu(inputs)
self.assertEquals(inputs.shape, outputs.shape)
- with self.cached_session() as sess:
- inputs, outputs = sess.run([inputs, outputs])
+
+ inputs, outputs = self.evaluate([inputs, outputs])
+
self.assertGreaterEqual(outputs.min(), 0.0)
self.assertLessEqual(outputs.max(), 1.0)
self.assertAllClose(inputs, outputs)
@@ -977,8 +970,9 @@
for dtype in [np.int32, np.int64, np.float16, np.float32, np.float64]:
np_values = np.array([-2, -1, 0, 1, 2], dtype=dtype)
outputs = nn_ops.leaky_relu(constant_op.constant(np_values))
- with self.cached_session() as sess:
- outputs = self.evaluate(outputs)
+
+ outputs = self.evaluate(outputs)
+
tol = 2e-3 if dtype == np.float16 else 1e-6
self.assertAllClose(
outputs, [-0.4, -0.2, 0.0, 1.0, 2.0], rtol=tol, atol=tol)
@@ -1004,9 +998,10 @@
tf_values = constant_op.constant(np_values)
actual_tf_outputs = nn_impl.swish(tf_values)
expected_tf_outputs = tf_values * math_ops.sigmoid(tf_values)
- with self.cached_session() as sess:
- actual_outputs, expected_outputs = sess.run(
- [actual_tf_outputs, expected_tf_outputs])
+
+ actual_outputs, expected_outputs = self.evaluate(
+ [actual_tf_outputs, expected_tf_outputs])
+
self.assertAllClose(actual_outputs, expected_outputs)
def testGradients(self):
@@ -1051,7 +1046,7 @@
self.assertLess(err, 1e-3)
# Evaluate.
- [mean, variance] = sess.run([mean, variance])
+ [mean, variance] = self.evaluate([mean, variance])
# Make sure that there are no NaNs
self.assertFalse(np.isnan(mean).any())
self.assertFalse(np.isnan(variance).any())
@@ -1094,9 +1089,9 @@
def _test(self, x_val, y_val_expected):
x = constant_op.constant(x_val)
y = nn_ops.data_format_dim_map(x)
- with self.cached_session(use_gpu=test_lib.is_gpu_available()) as sess:
- y_val = self.evaluate(y)
- self.assertAllEqual(y_val, y_val_expected)
+
+ y_val = self.evaluate(y)
+ self.assertAllEqual(y_val, y_val_expected)
def test(self):
self._test(0, 0)
@@ -1117,7 +1112,7 @@
y_val_expected = [2, 2, 3]
x = constant_op.constant(x_val)
y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="NCHW")
- with self.session(use_gpu=test_lib.is_gpu_available()) as sess:
+ with test_util.use_gpu():
y_val = self.evaluate(y)
self.assertAllEqual(y_val, y_val_expected)
@@ -1126,7 +1121,7 @@
y_val_expected = [2, 0, 1, 3, 2, 0, 1, 3]
x = constant_op.constant(x_val)
y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="HWNC")
- with self.session(use_gpu=test_lib.is_gpu_available()) as sess:
+ with test_util.use_gpu():
y_val = self.evaluate(y)
self.assertAllEqual(y_val, y_val_expected)
@@ -1135,7 +1130,7 @@
y_val_expected = [3, 1, 0, 2, 3, 1, 0, 2]
x = constant_op.constant(x_val)
y = nn_ops.data_format_dim_map(x, src_format="NHWC", dst_format="WHCN")
- with self.session(use_gpu=test_lib.is_gpu_available()) as sess:
+ with test_util.use_gpu():
y_val = self.evaluate(y)
self.assertAllEqual(y_val, y_val_expected)
@@ -1144,7 +1139,7 @@
y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0]
x = constant_op.constant(x_val)
y = nn_ops.data_format_dim_map(x, src_format="qwer", dst_format="rewq")
- with self.session(use_gpu=test_lib.is_gpu_available()) as sess:
+ with test_util.use_gpu():
y_val = self.evaluate(y)
self.assertAllEqual(y_val, y_val_expected)
@@ -1155,7 +1150,7 @@
x_val = [7, 4, 9, 3]
x = constant_op.constant(x_val)
y = nn_ops.data_format_vec_permute(x)
- with self.session(use_gpu=test_lib.is_gpu_available()) as sess:
+ with test_util.use_gpu():
y_val = self.evaluate(y)
self.assertAllEqual(y_val, [7, 3, 4, 9])
@@ -1163,7 +1158,7 @@
x_val = [7, 4, 9, 3]
x = constant_op.constant(x_val)
y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC")
- with self.session(use_gpu=test_lib.is_gpu_available()) as sess:
+ with test_util.use_gpu():
y_val = self.evaluate(y)
self.assertAllEqual(y_val, [7, 9, 3, 4])
@@ -1171,7 +1166,7 @@
x_val = [7, 4, 9, 3]
x = constant_op.constant(x_val)
y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC")
- with self.session(use_gpu=test_lib.is_gpu_available()) as sess:
+ with test_util.use_gpu():
y_val = self.evaluate(y)
self.assertAllEqual(y_val, [4, 9, 7, 3])
@@ -1179,7 +1174,7 @@
x_val = [7, 4, 9, 3]
x = constant_op.constant(x_val)
y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC")
- with self.session(use_gpu=test_lib.is_gpu_available()) as sess:
+ with test_util.use_gpu():
y_val = self.evaluate(y)
self.assertAllEqual(y_val, [9, 7, 4, 3])
@@ -1187,7 +1182,7 @@
x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
x = constant_op.constant(x_val)
y = nn_ops.data_format_vec_permute(x)
- with self.session(use_gpu=test_lib.is_gpu_available()) as sess:
+ with test_util.use_gpu():
y_val = self.evaluate(y)
self.assertAllEqual(y_val, [[7, 4], [5, 1], [9, 3], [4, 5]])
@@ -1195,7 +1190,7 @@
x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
x = constant_op.constant(x_val)
y = nn_ops.data_format_vec_permute(x, src_format="NHWC", dst_format="HWNC")
- with self.session(use_gpu=test_lib.is_gpu_available()) as sess:
+ with test_util.use_gpu():
y_val = self.evaluate(y)
self.assertAllEqual(y_val, [[9, 3], [4, 5], [7, 4], [5, 1]])
@@ -1203,7 +1198,7 @@
x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
x = constant_op.constant(x_val)
y = nn_ops.data_format_vec_permute(x, src_format="HWNC", dst_format="NHWC")
- with self.session(use_gpu=test_lib.is_gpu_available()) as sess:
+ with test_util.use_gpu():
y_val = self.evaluate(y)
self.assertAllEqual(y_val, [[4, 5], [7, 4], [9, 3], [5, 1]])
@@ -1211,7 +1206,7 @@
x_val = [[7, 4], [9, 3], [4, 5], [5, 1]]
x = constant_op.constant(x_val)
y = nn_ops.data_format_vec_permute(x, src_format="NCHW", dst_format="NHWC")
- with self.session(use_gpu=test_lib.is_gpu_available()) as sess:
+ with test_util.use_gpu():
y_val = self.evaluate(y)
self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]])
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py
index ead7ae5..8f652e9 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py
@@ -17,16 +17,20 @@
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.parallel_for.pfor import PFor
from tensorflow.python.util import nest
-def for_loop(loop_fn, loop_fn_dtypes, iters):
+def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
"""Runs `loop_fn` `iters` times and stacks the outputs.
@@ -39,6 +43,8 @@
objects. The shape of these outputs should not depend on the input.
loop_fn_dtypes: dtypes for the outputs of loop_fn.
iters: Number of iterations for which to run loop_fn.
+ parallel_iterations: The number of iterations that can be dispatched in
+ parallel. This knob can be used to control the total memory usage.
Returns:
Returns a nested structure of stacked output tensor objects with the same
@@ -66,11 +72,16 @@
outputs.append(ta)
return tuple([i + 1] + outputs)
+ if parallel_iterations is not None:
+ extra_args = {"parallel_iterations": parallel_iterations}
+ else:
+ extra_args = {}
ta_list = control_flow_ops.while_loop(
- lambda i, *ta: i < iters, while_body, [0] + [
- tensor_array_ops.TensorArray(dtype, iters)
- for dtype in flat_loop_fn_dtypes
- ])[1:]
+ lambda i, *ta: i < iters,
+ while_body,
+ [0] + [tensor_array_ops.TensorArray(dtype, iters)
+ for dtype in flat_loop_fn_dtypes],
+ **extra_args)[1:]
# TODO(rachelim): enable this for sparse tensors
@@ -79,7 +90,15 @@
return nest.pack_sequence_as(loop_fn_dtypes, output)
-def pfor(loop_fn, iters):
+def _flatten_first_two_dims(x):
+ """Flattens the first two dimensions of x into a single dimension."""
+ old_shape = array_ops.shape(x)
+ new_shape = array_ops.concat([[old_shape[0] * old_shape[1]], old_shape[2:]],
+ axis=0)
+ return array_ops.reshape(x, new_shape)
+
+
+def pfor(loop_fn, iters, parallel_iterations=None):
"""Equivalent to running `loop_fn` `iters` times and stacking the outputs.
`pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters`
@@ -99,8 +118,8 @@
reads, etc).
- Conversion works only on a limited set of kernels for which a converter
has been registered.
- - loop_fn cannot currently contain control flow operations like
- tf.while_loop or tf.cond.
+ - loop_fn has limited support for control flow operations. tf.cond in
+ particular is not supported.
- `loop_fn` should return nested structure of Tensors or Operations. However
if an Operation is returned, it should have zero outputs.
- The shape and dtype of `loop_fn` outputs should not depend on the input
@@ -109,22 +128,92 @@
Args:
loop_fn: A function that takes an int32 scalar tf.Tensor object representing
the iteration number, and returns a possibly nested structure of Tensor or
- Operation objects.
+ Operation objects. Note that if setting `parallel_iterations` argument to
+ something other than None, `loop_fn` may be called more than once during
+ graph construction. So it may need to avoid mutating global state.
iters: Number of iterations for which to run loop_fn.
+ parallel_iterations: A knob to control how many iterations are vectorized
+ and dispatched in parallel. The default value of None corresponds to
+ vectorizing all the iterations. If `parallel_iterations` is smaller than
+ `iters`, then chunks of at most that many iterations are dispatched in
+ sequence. This knob can be used to control the total memory usage.
Returns:
Returns a nested structure of stacked tensor objects with the same nested
structure as the output of `loop_fn`.
+ Raises:
+ ValueError: If parallel_iterations is not None and not an integer > 1.
"""
+ def f():
+ return _pfor_impl(loop_fn, iters, parallel_iterations=parallel_iterations)
+ if context.executing_eagerly():
+ f = function.defun(f)
+ return f()
+
+
+def _pfor_impl(loop_fn, iters, parallel_iterations=None):
+ """Implementation of pfor."""
existing_ops = set(ops.get_default_graph().get_operations())
with ops.name_scope("loop_body"):
loop_var = array_ops.placeholder(dtypes.int32, shape=[])
loop_fn_outputs = loop_fn(loop_var)
new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
iters = ops.convert_to_tensor(iters)
- with ops.name_scope("pfor"):
- converter = PFor(loop_var, iters, new_ops)
- outputs = []
- for loop_fn_output in nest.flatten(loop_fn_outputs):
- outputs.append(converter.convert(loop_fn_output))
- return nest.pack_sequence_as(loop_fn_outputs, outputs)
+ if parallel_iterations is not None:
+ if parallel_iterations < 1:
+ raise ValueError("parallel_iterations must be None or a positive integer")
+ if parallel_iterations == 1:
+ raise ValueError("Found parallel_iterations == 1. Use for_loop instead.")
+ iters_value = tensor_util.constant_value(iters)
+ if iters_value is not None and iters_value < parallel_iterations:
+ parallel_iterations = None
+ if parallel_iterations is None:
+ with ops.name_scope("pfor"):
+ converter = PFor(loop_var, iters, new_ops)
+ outputs = []
+ for loop_fn_output in nest.flatten(loop_fn_outputs):
+ outputs.append(converter.convert(loop_fn_output))
+ return nest.pack_sequence_as(loop_fn_outputs, outputs)
+ else:
+ num_tiled_iterations = iters // parallel_iterations
+ num_remaining_iterations = iters % parallel_iterations
+ # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
+ # a tf.function and extract the graph from there to vectorize it.
+ with ops.name_scope("pfor_untiled"):
+ converter = PFor(loop_var, num_remaining_iterations, new_ops)
+ remaining_outputs = []
+ flattened_loop_fn_outputs = nest.flatten(loop_fn_outputs)
+ for loop_fn_output in flattened_loop_fn_outputs:
+ remaining_outputs.append(converter.convert(loop_fn_output))
+
+ with ops.name_scope("pfor_tiled"):
+ loop_fn_dtypes = [ops.convert_to_tensor(x).dtype
+ for x in flattened_loop_fn_outputs]
+
+ def tiled_loop_body(j):
+ offset = j * parallel_iterations + num_remaining_iterations
+
+ def tiled_loop_fn(i):
+ return nest.flatten(loop_fn(i + offset))
+
+ return pfor(tiled_loop_fn, parallel_iterations)
+
+ tiled_outputs = for_loop(tiled_loop_body, loop_fn_dtypes,
+ num_tiled_iterations, parallel_iterations=1)
+ tiled_outputs = [_flatten_first_two_dims(y) for y in tiled_outputs]
+
+ with ops.name_scope("pfor"):
+ iters_value = tensor_util.constant_value(iters)
+ if iters_value is None or iters_value % parallel_iterations:
+ outputs = control_flow_ops.cond(
+ math_ops.equal(num_remaining_iterations, 0),
+ lambda: tiled_outputs,
+ lambda: [array_ops.concat([x, y], axis=0)
+ for x, y in zip(remaining_outputs, tiled_outputs)])
+ else:
+ outputs = tiled_outputs
+ return nest.pack_sequence_as(loop_fn_outputs, nest.flatten(outputs))
+
+
+
+
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index 171369b..c248476 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -26,10 +26,12 @@
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.client import session
+from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import clip_ops
@@ -52,6 +54,7 @@
from tensorflow.python.util import nest
+@test_util.run_all_in_graph_and_eager_modes
class PForTest(test.TestCase):
def _run_targets(self, targets1, targets2=None, run_init=True):
@@ -73,9 +76,13 @@
else:
self.assertAllEqual(outputs[i + n], outputs[i])
- def _test_loop_fn(self, loop_fn, iters, loop_fn_dtypes=dtypes.float32):
- t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters)
- t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters)
+ def _test_loop_fn(self, loop_fn, iters,
+ loop_fn_dtypes=dtypes.float32,
+ parallel_iterations=None):
+ t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters,
+ parallel_iterations=parallel_iterations)
+ t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters,
+ parallel_iterations=parallel_iterations)
self.run_and_assert_equal(t1, t2)
def test_op_conversion_fallback_to_while_loop(self):
@@ -96,7 +103,32 @@
loop_fn, 3, loop_fn_dtypes=[dtypes.float32, dtypes.int32])
flags.FLAGS.op_conversion_fallback_to_while_loop = False
+ def test_parallel_iterations(self):
+ for parallel_iterations in [2, 3, 8, 10]:
+ x = random_ops.random_uniform([8, 3])
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ return array_ops.gather(x, i)
+ # pylint: enable=cell-var-from-loop
+
+ self._test_loop_fn(loop_fn, 8, parallel_iterations=parallel_iterations)
+ self._test_loop_fn(loop_fn, 4 * constant_op.constant(2),
+ parallel_iterations=parallel_iterations)
+
+ def test_parallel_iterations_zero(self):
+ with self.assertRaisesRegexp(ValueError, "positive integer"):
+ pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=0)
+ with self.assertRaisesRegexp(TypeError, "positive integer"):
+ pfor_control_flow_ops.for_loop(lambda i: 1, dtypes.int32, 8,
+ parallel_iterations=0)
+
+ def test_parallel_iterations_one(self):
+ with self.assertRaisesRegexp(ValueError, "Use for_loop instead"):
+ pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=1)
+
+
+@test_util.run_all_in_graph_and_eager_modes
class ArrayTest(PForTest):
def test_gather(self):
@@ -288,14 +320,17 @@
def test_unary_cwise_ops(self):
for op in [array_ops.identity, array_ops.stop_gradient]:
- x = random_ops.random_uniform([3, 5])
+ with backprop.GradientTape(persistent=True) as g:
+ x = random_ops.random_uniform([3, 5])
+ g.watch(x)
# pylint: disable=cell-var-from-loop
def loop_fn(i):
- x1 = array_ops.gather(x, i)
- y = op(x1) + x1
- loss = nn.l2_loss(y)
- return op(x), y, gradient_ops.gradients(loss, x1)
+ with g:
+ x1 = array_ops.gather(x, i)
+ y = op(x1) + x1
+ loss = nn.l2_loss(y)
+ return op(x), y, g.gradient(loss, x1)
# pylint: enable=cell-var-from-loop
@@ -318,17 +353,21 @@
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32])
def test_strided_slice(self):
- x = random_ops.random_uniform([3, 3, 4, 4, 2, 2, 2])
+ with backprop.GradientTape(persistent=True) as g:
+ x = random_ops.random_uniform([3, 3, 4, 4, 2, 2, 2])
+ g.watch(x)
def loop_fn(i):
- x_i = array_ops.gather(x, i)
- y = x_i[:2, ::2, 1::3, ..., array_ops.newaxis, 1]
- loss = nn.l2_loss(y)
- return y, gradient_ops.gradients(loss, x_i)
+ with g:
+ x_i = array_ops.gather(x, i)
+ y = x_i[:2, ::2, 1::3, ..., array_ops.newaxis, 1]
+ loss = nn.l2_loss(y)
+ return y, g.gradient(loss, x_i)
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
+@test_util.run_all_in_graph_and_eager_modes
class BitwiseTest(PForTest):
def test_unary_cwise(self):
@@ -368,6 +407,7 @@
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=output_dtypes)
+@test_util.run_all_in_graph_and_eager_modes
class MathTest(PForTest):
def test_unary_cwise_ops(self):
@@ -424,22 +464,29 @@
nn.softsign,
]
for op in complex_ops + real_ops:
- x = random_ops.random_uniform([3, 5])
- if op in complex_ops:
- y = random_ops.random_uniform([3, 5])
- x = math_ops.complex(x, y)
+ with backprop.GradientTape(persistent=True) as g:
+ x = random_ops.random_uniform([3, 5])
+ g.watch(x)
+ if op in complex_ops:
+ y = random_ops.random_uniform([3, 5])
+ g.watch(y)
+ x = math_ops.complex(x, y)
# pylint: disable=cell-var-from-loop
output_dtypes = []
def loop_fn(i):
- x1 = array_ops.gather(x, i)
- y1 = op(x1)
- outputs = [op(x), y1]
- if y1.dtype == dtypes.float32:
- loss = math_ops.reduce_sum(y1 * y1)
- grad = gradient_ops.gradients(loss, x1)
- if grad and grad[0] is not None:
- outputs.extend(grad)
+ with g:
+ x1 = array_ops.gather(x, i)
+ y1 = op(x1)
+ outputs = [op(x), y1]
+ if y1.dtype == dtypes.float32:
+ loss = math_ops.reduce_sum(y1 * y1)
+ else:
+ loss = None
+ if loss is not None:
+ grad = g.gradient(loss, x1)
+ if grad is not None:
+ outputs.append(grad)
del output_dtypes[:]
output_dtypes.extend([t.dtype for t in outputs])
return outputs
@@ -656,17 +703,19 @@
x_shape = [2, 3, 4, 5, 6]
x = random_ops.random_uniform(x_shape)
for data_format in ("NCHW", "NHWC"):
- bias_dim = 2 if data_format == "NCHW" else -1
- bias_shape = x_shape[bias_dim]
- bias = random_ops.random_uniform([bias_shape])
+ with backprop.GradientTape(persistent=True) as g:
+ bias_dim = 2 if data_format == "NCHW" else -1
+ bias_shape = x_shape[bias_dim]
+ bias = random_ops.random_uniform([bias_shape])
+ g.watch(bias)
# pylint: disable=cell-var-from-loop
def loop_fn(i):
- a = array_ops.gather(x, i)
- y = nn.bias_add(a, bias, data_format=data_format)
- loss = math_ops.reduce_sum(y * y)
- return y, gradient_ops.gradients(loss, bias)
-
+ with g:
+ a = array_ops.gather(x, i)
+ y = nn.bias_add(a, bias, data_format=data_format)
+ loss = math_ops.reduce_sum(y * y)
+ return y, g.gradient(loss, bias)
# pylint: enable=cell-var-from-loop
self._test_loop_fn(
@@ -727,6 +776,7 @@
self._test_loop_fn(loop_fn, 2)
+@test_util.run_all_in_graph_and_eager_modes
class NNTest(PForTest):
def test_conv2d(self):
@@ -779,30 +829,60 @@
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
def test_avg_pool(self):
- x = random_ops.random_uniform([3, 2, 12, 12, 3])
- ksize = [1, 3, 3, 1]
+ with backprop.GradientTape(persistent=True) as g:
+ x = random_ops.random_uniform([3, 2, 12, 12, 3])
+ g.watch(x)
+ ksize = [1, 3, 3, 1]
def loop_fn(i):
- x1 = array_ops.gather(x, i)
- output = nn.avg_pool(
- x1, ksize, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC")
- loss = nn.l2_loss(output)
- return output, gradient_ops.gradients(loss, x1)
+ with g:
+ x1 = array_ops.gather(x, i)
+ output = nn.avg_pool(
+ x1, ksize, strides=[1, 2, 2, 1], padding="VALID",
+ data_format="NHWC")
+ loss = nn.l2_loss(output)
+ return output, g.gradient(loss, x1)
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
def test_max_pool(self):
- x = random_ops.random_uniform([3, 2, 12, 12, 3])
- ksize = [1, 3, 3, 1]
+ with backprop.GradientTape(persistent=True) as g:
+ x = random_ops.random_uniform([3, 2, 12, 12, 3])
+ g.watch(x)
+ ksize = [1, 3, 3, 1]
+ strides = [1, 2, 2, 1]
def loop_fn(i):
- x1 = array_ops.gather(x, i)
- output = nn.max_pool(
- x1, ksize, strides=[1, 2, 2, 1], padding="VALID", data_format="NHWC")
- loss = nn.l2_loss(output)
- ones = array_ops.ones_like(output)
- grad = gradient_ops.gradients(loss, x1, grad_ys=ones)
- grad_grad = gradient_ops.gradients(grad, ones)
+ with g:
+ x1 = array_ops.gather(x, i)
+ output = nn.max_pool(
+ x1, ksize, strides=strides, padding="VALID", data_format="NHWC")
+ loss = nn.l2_loss(output)
+ ones = array_ops.ones_like(output)
+ g.watch(ones)
+ grad = g.gradient(loss, x1, output_gradients=ones)
+ grad_grad = g.gradient(grad, ones)
+ return output, grad, grad_grad
+
+ self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3)
+
+ def test_max_pool3d(self):
+ with backprop.GradientTape(persistent=True) as g:
+ x = random_ops.random_uniform([3, 3, 2, 12, 12, 3])
+ g.watch(x)
+ ksize = [1, 1, 3, 3, 1]
+ strides = [1, 1, 2, 2, 1]
+
+ def loop_fn(i):
+ with g:
+ x1 = array_ops.gather(x, i)
+ output = nn.max_pool3d(
+ x1, ksize, strides=strides, padding="VALID", data_format="NDHWC")
+ loss = nn.l2_loss(output)
+ ones = array_ops.ones_like(output)
+ g.watch(ones)
+ grad = g.gradient(loss, x1, output_gradients=ones)
+ grad_grad = g.gradient(grad, ones)
return output, grad, grad_grad
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3)
@@ -813,36 +893,41 @@
data_formats.append("NCHW")
for is_training in (True, False):
for data_format in data_formats:
- if data_format == "NCHW":
- x = random_ops.random_uniform([3, 1, 2, 5, 5])
- else:
- x = random_ops.random_uniform([3, 1, 5, 5, 2])
- scale = random_ops.random_uniform([2])
- offset = random_ops.random_uniform([2])
- mean = None if is_training else random_ops.random_uniform([2])
- variance = None if is_training else random_ops.random_uniform([2])
+ with backprop.GradientTape(persistent=True) as g:
+ if data_format == "NCHW":
+ x = random_ops.random_uniform([3, 1, 2, 5, 5])
+ else:
+ x = random_ops.random_uniform([3, 1, 5, 5, 2])
+ g.watch(x)
+ scale = random_ops.random_uniform([2])
+ g.watch(scale)
+ offset = random_ops.random_uniform([2])
+ g.watch(offset)
+ mean = None if is_training else random_ops.random_uniform([2])
+ variance = None if is_training else random_ops.random_uniform([2])
# pylint: disable=cell-var-from-loop
def loop_fn(i):
- x1 = array_ops.gather(x, i)
- outputs = nn.fused_batch_norm(
- x1,
- scale,
- offset,
- mean=mean,
- variance=variance,
- epsilon=0.01,
- data_format=data_format,
- is_training=is_training)
- outputs = list(outputs)
- # We only test the first value of outputs when is_training is False.
- # It looks like CPU and GPU have different outputs for batch_mean and
- # batch_variance for this case.
- if not is_training:
- outputs[1] = constant_op.constant(0.)
- outputs[2] = constant_op.constant(0.)
- loss = nn.l2_loss(outputs[0])
- gradients = gradient_ops.gradients(loss, [x1, scale, offset])
+ with g:
+ x1 = array_ops.gather(x, i)
+ outputs = nn.fused_batch_norm(
+ x1,
+ scale,
+ offset,
+ mean=mean,
+ variance=variance,
+ epsilon=0.01,
+ data_format=data_format,
+ is_training=is_training)
+ outputs = list(outputs)
+ # We only test the first value of outputs when is_training is False.
+ # It looks like CPU and GPU have different outputs for batch_mean
+ # and batch_variance for this case.
+ if not is_training:
+ outputs[1] = constant_op.constant(0.)
+ outputs[2] = constant_op.constant(0.)
+ loss = nn.l2_loss(outputs[0])
+ gradients = g.gradient(loss, [x1, scale, offset])
return outputs + gradients
# pylint: enable=cell-var-from-loop
@@ -850,16 +935,20 @@
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 6)
def test_softmax_cross_entropy_with_logits(self):
- logits = random_ops.random_uniform([3, 2, 4])
- labels = random_ops.random_uniform([3, 2, 4])
- labels /= math_ops.reduce_sum(labels, axis=[2], keepdims=True)
+ with backprop.GradientTape(persistent=True) as g:
+ logits = random_ops.random_uniform([3, 2, 4])
+ g.watch(logits)
+ labels = random_ops.random_uniform([3, 2, 4])
+ labels /= math_ops.reduce_sum(labels, axis=[2], keepdims=True)
def loop_fn(i):
- logits_i = array_ops.gather(logits, i)
- labels_i = array_ops.gather(labels, i)
- loss = nn.softmax_cross_entropy_with_logits(
- labels=labels_i, logits=logits_i)
- return loss, gradient_ops.gradients(math_ops.reduce_sum(loss), logits_i)
+ with g:
+ logits_i = array_ops.gather(logits, i)
+ labels_i = array_ops.gather(labels, i)
+ loss = nn.softmax_cross_entropy_with_logits(
+ labels=labels_i, logits=logits_i)
+ total_loss = math_ops.reduce_sum(loss)
+ return loss, g.gradient(total_loss, logits_i)
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
@@ -1044,7 +1133,7 @@
# y = x * x. Hence dy/dx = 2 * x.
actual_grad = 2.0 * x
with session.Session() as sess:
- actual_grad, computed_grad = sess.run([t1, actual_grad])
+ actual_grad, computed_grad = self.evaluate([t1, actual_grad])
self.assertAllClose(actual_grad, computed_grad)
@@ -1198,7 +1287,7 @@
expected_output = array_ops.transpose(expected_output, [1, 0])
with session.Session() as sess:
- out, expected = sess.run([out, expected_output])
+ out, expected = self.evaluate([out, expected_output])
self.assertAllClose(expected, out)
def test_tensor_array_as_loop_variable(self):
@@ -1278,13 +1367,12 @@
pfor_out, pfor_out_grad = pfor_control_flow_ops.pfor(loop_fn, 4)
# Note that tf.while_loop does not work in the setup above. So we manually
# construct the equivalent computation of the above loops here.
- real_out = math_ops.reduce_sum(inp, reduction_indices=[0])
- real_out = math_ops.reduce_prod(real_out, reduction_indices=[1])
+ real_out = math_ops.reduce_sum(inp, axis=[0])
+ real_out = math_ops.reduce_prod(real_out, axis=[1])
# Note that gradients of real_out will accumulate the gradients across the
# output value. Hence we do the same aggregation on pfor_out_grad.
real_out_grad = gradient_ops.gradients(real_out, inp)[0]
- sum_pfor_out_grad = math_ops.reduce_sum(
- pfor_out_grad, reduction_indices=[0])
+ sum_pfor_out_grad = math_ops.reduce_sum(pfor_out_grad, axis=[0])
with session.Session() as sess:
v1, v2, v1_grad, v2_grad = sess.run(
@@ -1387,7 +1475,7 @@
sess = session.Session()
with sess:
init = variables.global_variables_initializer()
- sess.run(init)
+ self.evaluate(init)
run_fn = sess.make_callable(targets)
run_fn() # Warm up
begin = time.time()
diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py
index 1f026b3..3ba1bde 100644
--- a/tensorflow/python/ops/parallel_for/gradients.py
+++ b/tensorflow/python/ops/parallel_for/gradients.py
@@ -25,7 +25,7 @@
from tensorflow.python.util import nest
-def jacobian(output, inputs, use_pfor=True):
+def jacobian(output, inputs, use_pfor=True, parallel_iterations=None):
"""Computes jacobian of `output` w.r.t. `inputs`.
Args:
@@ -33,6 +33,8 @@
inputs: A tensor or a nested structure of tensor objects.
use_pfor: If true, uses pfor for computing the jacobian. Else uses
tf.while_loop.
+ parallel_iterations: A knob to control how many iterations and dispatched in
+ parallel. This knob can be used to control the total memory usage.
Returns:
A tensor or a nested strucutre of tensors with the same structure as
@@ -56,10 +58,14 @@
output_size = array_ops.shape(output)[0]
if use_pfor:
- pfor_outputs = control_flow_ops.pfor(loop_fn, output_size)
+ pfor_outputs = control_flow_ops.pfor(
+ loop_fn, output_size, parallel_iterations=parallel_iterations)
else:
pfor_outputs = control_flow_ops.for_loop(
- loop_fn, [output.dtype] * len(flat_inputs), output_size)
+ loop_fn,
+ [output.dtype] * len(flat_inputs),
+ output_size,
+ parallel_iterations=parallel_iterations)
for i, out in enumerate(pfor_outputs):
if out is not None:
@@ -72,7 +78,7 @@
return nest.pack_sequence_as(inputs, pfor_outputs)
-def batch_jacobian(output, inp, use_pfor=True):
+def batch_jacobian(output, inp, use_pfor=True, parallel_iterations=None):
"""Computes and stacks jacobians of `output[i,...]` w.r.t. `input[i,...]`.
e.g.
@@ -87,6 +93,8 @@
inp: A tensor with shape [b, x1, ..., x_m]
use_pfor: If true, uses pfor for computing the Jacobian. Else uses a
tf.while_loop.
+ parallel_iterations: A knob to control how many iterations and dispatched in
+ parallel. This knob can be used to control the total memory usage.
Returns:
A tensor `t` with shape [b, y_1, ..., y_n, x1, ..., x_m] where `t[i, ...]`
@@ -118,10 +126,13 @@
return gradient_ops.gradients(y, inp)[0]
if use_pfor:
- pfor_output = control_flow_ops.pfor(loop_fn, output_row_size)
+ pfor_output = control_flow_ops.pfor(loop_fn, output_row_size,
+ parallel_iterations=parallel_iterations)
else:
- pfor_output = control_flow_ops.for_loop(loop_fn, output.dtype,
- output_row_size)
+ pfor_output = control_flow_ops.for_loop(
+ loop_fn, output.dtype,
+ output_row_size,
+ parallel_iterations=parallel_iterations)
if pfor_output is None:
return None
pfor_output = array_ops.reshape(pfor_output,
diff --git a/tensorflow/python/ops/parallel_for/gradients_test.py b/tensorflow/python/ops/parallel_for/gradients_test.py
index b2be24e..4342833 100644
--- a/tensorflow/python/ops/parallel_for/gradients_test.py
+++ b/tensorflow/python/ops/parallel_for/gradients_test.py
@@ -416,6 +416,12 @@
self.assertAllClose(ans, pfor_value)
self.assertAllClose(ans, while_value)
+ def test_jacobian_parallel_iterations(self):
+ x = constant_op.constant([[1., 2], [3, 4]])
+ y = math_ops.matmul(x, x)
+ self.assertAllClose(gradients.jacobian(y, x, parallel_iterations=2),
+ gradients.jacobian(y, x, parallel_iterations=3))
+
def test_batch_jacobian_bad_shapes(self):
x = random_ops.random_uniform([2, 2])
y = random_ops.random_uniform([3, 2])
@@ -459,6 +465,13 @@
self.assertAllClose(ans, pfor_value)
self.assertAllClose(ans, while_value)
+ def test_batch_jacobian_parallel_iterations(self):
+ x = constant_op.constant([[1., 2], [3, 4]])
+ w = constant_op.constant([[1., 2, 3, 4], [5, 6, 7, 8]])
+ y = math_ops.matmul(x, w)
+ self.assertAllClose(gradients.batch_jacobian(y, x, parallel_iterations=2),
+ gradients.batch_jacobian(y, x, parallel_iterations=3))
+
def test_fc_batch_jacobian(self):
pfor_jacobian, while_jacobian = create_fc_batch_jacobian(8, 4, 2)
self.run_and_assert_equal(pfor_jacobian, while_jacobian)
@@ -471,7 +484,7 @@
pfor_jacobian, while_gradients = create_dynamic_lstm_batch_jacobian(8, 4, 3)
with session.Session() as sess:
init = variables.global_variables_initializer()
- sess.run(init)
+ self.evaluate(init)
pfor = self.evaluate(pfor_jacobian)
for i in range(4):
while_i = sess.run(while_gradients[i])
@@ -547,11 +560,11 @@
sess = session.Session()
with sess:
init = variables.global_variables_initializer()
- sess.run(init)
- sess.run(targets)
+ self.evaluate(init)
+ self.evaluate(targets)
begin = time.time()
for _ in range(iters):
- sess.run(targets)
+ self.evaluate(targets)
end = time.time()
avg_time_ms = 1000 * (end - begin) / iters
self.report_benchmark(iters=iters, wall_time=avg_time_ms, name=name)
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index e6f140a..a22c112 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -1152,9 +1152,8 @@
continue
converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs]
- some_input_converted = any(
- [self._was_converted(x) for x in y_op.inputs])
- some_input_stacked = any([x.is_stacked for x in converted_inputs])
+ some_input_converted = any(self._was_converted(x) for x in y_op.inputs)
+ some_input_stacked = any(x.is_stacked for x in converted_inputs)
converted_control_ops = set()
some_control_input_converted = False
@@ -1198,7 +1197,7 @@
# All inputs are unstacked or uncoverted but some control inputs are
# converted.
# TODO(rachelim): Handle the case where some inputs are sparsely
- # stacked (i.e. any([x.is_sparse_stacked for x in converted_inputs]))
+ # stacked (i.e. any(x.is_sparse_stacked for x in converted_inputs))
new_op = _create_op(y_op.type, [x.t for x in converted_inputs],
[x.dtype for x in y_op.outputs],
y_op.node_def.attr)
@@ -1303,7 +1302,10 @@
@RegisterPForWithArgs("Conv2D", dims=[0])
@RegisterPForWithArgs("AvgPool", dims=[0])
@RegisterPForWithArgs("MaxPool", dims=[0])
+@RegisterPForWithArgs("MaxPool3D", dims=[0])
+@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2])
@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2])
+@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2])
@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2])
@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1])
def _convert_flatten_batch(pfor_input, op_type, dims):
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index 484caf0..a84af6c 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -363,7 +363,7 @@
return features
-@tf_export("io.parse_example", v1=["io.parse_example", "parse_example"])
+@tf_export(v1=["io.parse_example", "parse_example"])
def parse_example(serialized, features, name=None, example_names=None):
# pylint: disable=line-too-long
"""Parses `Example` protos into a `dict` of tensors.
@@ -577,6 +577,223 @@
Raises:
ValueError: if any feature is invalid.
"""
+ return parse_example_v2(serialized, features, example_names, name)
+
+
+@tf_export("io.parse_example", v1=[])
+def parse_example_v2(serialized, features, example_names=None, name=None):
+ # pylint: disable=line-too-long
+ """Parses `Example` protos into a `dict` of tensors.
+
+ Parses a number of serialized [`Example`](https://www.tensorflow.org/code/tensorflow/core/example/example.proto)
+ protos given in `serialized`. We refer to `serialized` as a batch with
+ `batch_size` many entries of individual `Example` protos.
+
+ `example_names` may contain descriptive names for the corresponding serialized
+ protos. These may be useful for debugging purposes, but they have no effect on
+ the output. If not `None`, `example_names` must be the same length as
+ `serialized`.
+
+ This op parses serialized examples into a dictionary mapping keys to `Tensor`
+ and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`,
+ `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature`
+ and `SparseFeature` is mapped to a `SparseTensor`, and each
+ `FixedLenFeature` is mapped to a `Tensor`.
+
+ Each `VarLenFeature` maps to a `SparseTensor` of the specified type
+ representing a ragged matrix. Its indices are `[batch, index]` where `batch`
+ identifies the example in `serialized`, and `index` is the value's index in
+ the list of values associated with that feature and example.
+
+ Each `SparseFeature` maps to a `SparseTensor` of the specified type
+ representing a Tensor of `dense_shape` `[batch_size] + SparseFeature.size`.
+ Its `values` come from the feature in the examples with key `value_key`.
+ A `values[i]` comes from a position `k` in the feature of an example at batch
+ entry `batch`. This positional information is recorded in `indices[i]` as
+ `[batch, index_0, index_1, ...]` where `index_j` is the `k-th` value of
+ the feature in the example at with key `SparseFeature.index_key[j]`.
+ In other words, we split the indices (except the first index indicating the
+ batch entry) of a `SparseTensor` by dimension into different features of the
+ `Example`. Due to its complexity a `VarLenFeature` should be preferred over a
+ `SparseFeature` whenever possible.
+
+ Each `FixedLenFeature` `df` maps to a `Tensor` of the specified type (or
+ `tf.float32` if not specified) and shape `(serialized.size(),) + df.shape`.
+
+ `FixedLenFeature` entries with a `default_value` are optional. With no default
+ value, we will fail if that `Feature` is missing from any example in
+ `serialized`.
+
+ Each `FixedLenSequenceFeature` `df` maps to a `Tensor` of the specified type
+ (or `tf.float32` if not specified) and shape
+ `(serialized.size(), None) + df.shape`.
+ All examples in `serialized` will be padded with `default_value` along the
+ second dimension.
+
+ Examples:
+
+ For example, if one expects a `tf.float32` `VarLenFeature` `ft` and three
+ serialized `Example`s are provided:
+
+ ```
+ serialized = [
+ features
+ { feature { key: "ft" value { float_list { value: [1.0, 2.0] } } } },
+ features
+ { feature []},
+ features
+ { feature { key: "ft" value { float_list { value: [3.0] } } }
+ ]
+ ```
+
+ then the output will look like:
+
+ ```python
+ {"ft": SparseTensor(indices=[[0, 0], [0, 1], [2, 0]],
+ values=[1.0, 2.0, 3.0],
+ dense_shape=(3, 2)) }
+ ```
+
+ If instead a `FixedLenSequenceFeature` with `default_value = -1.0` and
+ `shape=[]` is used then the output will look like:
+
+ ```python
+ {"ft": [[1.0, 2.0], [3.0, -1.0]]}
+ ```
+
+ Given two `Example` input protos in `serialized`:
+
+ ```
+ [
+ features {
+ feature { key: "kw" value { bytes_list { value: [ "knit", "big" ] } } }
+ feature { key: "gps" value { float_list { value: [] } } }
+ },
+ features {
+ feature { key: "kw" value { bytes_list { value: [ "emmy" ] } } }
+ feature { key: "dank" value { int64_list { value: [ 42 ] } } }
+ feature { key: "gps" value { } }
+ }
+ ]
+ ```
+
+ And arguments
+
+ ```
+ example_names: ["input0", "input1"],
+ features: {
+ "kw": VarLenFeature(tf.string),
+ "dank": VarLenFeature(tf.int64),
+ "gps": VarLenFeature(tf.float32),
+ }
+ ```
+
+ Then the output is a dictionary:
+
+ ```python
+ {
+ "kw": SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0]],
+ values=["knit", "big", "emmy"]
+ dense_shape=[2, 2]),
+ "dank": SparseTensor(
+ indices=[[1, 0]],
+ values=[42],
+ dense_shape=[2, 1]),
+ "gps": SparseTensor(
+ indices=[],
+ values=[],
+ dense_shape=[2, 0]),
+ }
+ ```
+
+ For dense results in two serialized `Example`s:
+
+ ```
+ [
+ features {
+ feature { key: "age" value { int64_list { value: [ 0 ] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
+ },
+ features {
+ feature { key: "age" value { int64_list { value: [] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
+ }
+ ]
+ ```
+
+ We can use arguments:
+
+ ```
+ example_names: ["input0", "input1"],
+ features: {
+ "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
+ "gender": FixedLenFeature([], dtype=tf.string),
+ }
+ ```
+
+ And the expected output is:
+
+ ```python
+ {
+ "age": [[0], [-1]],
+ "gender": [["f"], ["f"]],
+ }
+ ```
+
+ An alternative to `VarLenFeature` to obtain a `SparseTensor` is
+ `SparseFeature`. For example, given two `Example` input protos in
+ `serialized`:
+
+ ```
+ [
+ features {
+ feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } }
+ feature { key: "ix" value { int64_list { value: [ 3, 20 ] } } }
+ },
+ features {
+ feature { key: "val" value { float_list { value: [ 0.0 ] } } }
+ feature { key: "ix" value { int64_list { value: [ 42 ] } } }
+ }
+ ]
+ ```
+
+ And arguments
+
+ ```
+ example_names: ["input0", "input1"],
+ features: {
+ "sparse": SparseFeature(
+ index_key="ix", value_key="val", dtype=tf.float32, size=100),
+ }
+ ```
+
+ Then the output is a dictionary:
+
+ ```python
+ {
+ "sparse": SparseTensor(
+ indices=[[0, 3], [0, 20], [1, 42]],
+ values=[0.5, -1.0, 0.0]
+ dense_shape=[2, 100]),
+ }
+ ```
+
+ Args:
+ serialized: A vector (1-D Tensor) of strings, a batch of binary
+ serialized `Example` protos.
+ features: A `dict` mapping feature keys to `FixedLenFeature`,
+ `VarLenFeature`, and `SparseFeature` values.
+ example_names: A vector (1-D Tensor) of strings (optional), the names of
+ the serialized protos in the batch.
+ name: A name for this operation (optional).
+
+ Returns:
+ A `dict` mapping feature keys to `Tensor` and `SparseTensor` values.
+
+ Raises:
+ ValueError: if any feature is invalid.
+ """
if not features:
raise ValueError("Missing: features was %s." % features)
features = _prepend_none_dimension(features)
@@ -764,8 +981,7 @@
dense_shapes_as_proto, dense_shapes)
-@tf_export("io.parse_single_example",
- v1=["io.parse_single_example", "parse_single_example"])
+@tf_export(v1=["io.parse_single_example", "parse_single_example"])
def parse_single_example(serialized, features, name=None, example_names=None):
"""Parses a single `Example` proto.
@@ -798,6 +1014,48 @@
Raises:
ValueError: if any feature is invalid.
"""
+ return parse_single_example_v2_unoptimized(
+ serialized, features, example_names, name
+ )
+
+
+# TODO(b/70890287): Combine the implementation of this op and
+# `parse_single_example_v2()` after 1/10/2018.
+@tf_export("io.parse_single_example", v1=[])
+def parse_single_example_v2_unoptimized(
+ serialized, features, example_names=None, name=None
+ ):
+ """Parses a single `Example` proto.
+
+ Similar to `parse_example`, except:
+
+ For dense tensors, the returned `Tensor` is identical to the output of
+ `parse_example`, except there is no batch dimension, the output shape is the
+ same as the shape given in `dense_shape`.
+
+ For `SparseTensor`s, the first (batch) column of the indices matrix is removed
+ (the indices matrix is a column vector), the values vector is unchanged, and
+ the first (`batch_size`) entry of the shape vector is removed (it is now a
+ single element vector).
+
+ One might see performance advantages by batching `Example` protos with
+ `parse_example` instead of using this function directly.
+
+ Args:
+ serialized: A scalar string Tensor, a single serialized Example.
+ See `_parse_single_example_raw` documentation for more details.
+ features: A `dict` mapping feature keys to `FixedLenFeature` or
+ `VarLenFeature` values.
+ example_names: (Optional) A scalar string Tensor, the associated name.
+ See `_parse_single_example_raw` documentation for more details.
+ name: A name for this operation (optional).
+
+ Returns:
+ A `dict` mapping feature keys to `Tensor` and `SparseTensor` values.
+
+ Raises:
+ ValueError: if any feature is invalid.
+ """
if not features:
raise ValueError("Missing features.")
if example_names is None:
@@ -1570,7 +1828,7 @@
# Swap `name` and `na_value` for backward compatibility.
-@tf_export("io.decode_csv", v1=["io.decode_csv", "decode_csv"])
+@tf_export(v1=["io.decode_csv", "decode_csv"])
@deprecation.deprecated_endpoints("decode_csv")
def decode_csv(records,
record_defaults,
@@ -1612,6 +1870,54 @@
Raises:
ValueError: If any of the arguments is malformed.
"""
+ return decode_csv_v2(
+ records, record_defaults,
+ field_delim, use_quote_delim,
+ na_value, select_cols, name
+ )
+
+
+@tf_export("io.decode_csv", v1=[])
+def decode_csv_v2(records,
+ record_defaults,
+ field_delim=",",
+ use_quote_delim=True,
+ na_value="",
+ select_cols=None,
+ name=None):
+ """Convert CSV records to tensors. Each column maps to one tensor.
+
+ RFC 4180 format is expected for the CSV records.
+ (https://tools.ietf.org/html/rfc4180)
+ Note that we allow leading and trailing spaces with int or float field.
+
+ Args:
+ records: A `Tensor` of type `string`.
+ Each string is a record/row in the csv and all records should have
+ the same format.
+ record_defaults: A list of `Tensor` objects with specific types.
+ Acceptable types are `float32`, `float64`, `int32`, `int64`, `string`.
+ One tensor per column of the input record, with either a
+ scalar default value for that column or an empty vector if the column is
+ required.
+ field_delim: An optional `string`. Defaults to `","`.
+ char delimiter to separate fields in a record.
+ use_quote_delim: An optional `bool`. Defaults to `True`.
+ If false, treats double quotation marks as regular
+ characters inside of the string fields (ignoring RFC 4180, Section 2,
+ Bullet 5).
+ na_value: Additional string to recognize as NA/NaN.
+ select_cols: Optional sorted list of column indices to select. If specified,
+ only this subset of columns will be parsed and returned.
+ name: A name for the operation (optional).
+
+ Returns:
+ A list of `Tensor` objects. Has the same type as `record_defaults`.
+ Each tensor will have the same shape as records.
+
+ Raises:
+ ValueError: If any of the arguments is malformed.
+ """
if select_cols is not None and any(select_cols[i] >= select_cols[i + 1]
for i in range(len(select_cols) - 1)):
raise ValueError("select_cols is not strictly increasing.")
diff --git a/tensorflow/python/ops/partitioned_variables.py b/tensorflow/python/ops/partitioned_variables.py
index 7743b63..c1084c2 100644
--- a/tensorflow/python/ops/partitioned_variables.py
+++ b/tensorflow/python/ops/partitioned_variables.py
@@ -57,7 +57,7 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -68,7 +68,7 @@
]
-@tf_export("variable_axis_size_partitioner")
+@tf_export(v1=["variable_axis_size_partitioner"])
def variable_axis_size_partitioner(
max_shard_bytes, axis=0, bytes_per_string_element=16, max_shards=None):
"""Get a partitioner for VariableScope to keep shards below `max_shard_bytes`.
@@ -96,7 +96,7 @@
Returns:
A partition function usable as the `partitioner` argument to
- `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.
+ `variable_scope` and `get_variable`.
Raises:
ValueError: If any of the byte counts are non-positive.
@@ -154,7 +154,7 @@
return _partitioner
-@tf_export("min_max_variable_partitioner")
+@tf_export(v1=["min_max_variable_partitioner"])
def min_max_variable_partitioner(max_partitions=1, axis=0,
min_slice_size=256 << 10,
bytes_per_string_element=16):
@@ -175,7 +175,7 @@
Returns:
A partition function usable as the `partitioner` argument to
- `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.
+ `variable_scope` and `get_variable`.
"""
def _partitioner(shape, dtype):
@@ -218,7 +218,7 @@
return _partitioner
-@tf_export("fixed_size_partitioner")
+@tf_export(v1=["fixed_size_partitioner"])
def fixed_size_partitioner(num_shards, axis=0):
"""Partitioner to specify a fixed number of shards along given axis.
@@ -228,7 +228,7 @@
Returns:
A partition function usable as the `partitioner` argument to
- `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.
+ `variable_scope` and `get_variable`.
"""
def _partitioner(shape, **unused_args):
partitions_list = [1] * len(shape)
@@ -237,7 +237,10 @@
return _partitioner
-@tf_export("create_partitioned_variables")
+@tf_export(v1=["create_partitioned_variables"])
+@deprecation.deprecated(
+ date=None,
+ instructions="Use tf.get_variable with a partitioner set.")
def create_partitioned_variables(
shape, slicing, initializer, dtype=dtypes.float32,
trainable=True, collections=None, name=None, reuse=None):
@@ -282,11 +285,6 @@
Raises:
ValueError: If any of the arguments is malformed.
"""
- logging.warn(
- "create_partitioned_variables is deprecated. Use "
- "tf.get_variable with a partitioner set, or "
- "tf.get_partitioned_variable_list, instead.")
-
if len(shape) != len(slicing):
raise ValueError("The 'shape' and 'slicing' of a partitioned Variable "
"must have the length: shape: %s, slicing: %s" %
diff --git a/tensorflow/python/ops/ragged/ragged_array_ops.py b/tensorflow/python/ops/ragged/ragged_array_ops.py
index 425f395..815f48a 100644
--- a/tensorflow/python/ops/ragged/ragged_array_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_array_ops.py
@@ -1172,6 +1172,17 @@
ragged_rank = rt_input.ragged_rank
nested_splits = rt_input.nested_row_splits
+ # projected_splits[src_axis, dst_axis] contains the split points that divide
+ # the rows from src_axis in the list of dst_axis values. E.g.,
+ # projected_splits[i, i] = nested_splits[i], and
+ # projected_splits[i, i+1] = gather(nested_splits[i+1], nested_splits[i]).
+ projected_splits = [{i: nested_splits[i]} for i in range(ragged_rank)]
+ for src_axis in range(ragged_rank):
+ for dst_axis in range(src_axis + 1, ragged_rank - 1):
+ projected_splits[src_axis][dst_axis] = array_ops.gather(
+ nested_splits[dst_axis],
+ projected_splits[src_axis][dst_axis - 1])
+
# For each ragged dimension: nested_splits[axis] -> result_splits[axis].
result_splits = []
for axis in range(ragged_rank):
@@ -1188,7 +1199,7 @@
repeats = 1
for d in range(axis - 1, -1, -1):
if const_multiples is None or const_multiples[d + 1] != 1:
- splits = nested_splits[d] * repeats
+ splits = projected_splits[d][axis - 1] * repeats
output_lengths = _repeat_ranges(output_lengths, splits,
multiples[d + 1])
repeats *= multiples[d + 1]
diff --git a/tensorflow/python/ops/ragged/ragged_const_op_test.py b/tensorflow/python/ops/ragged/ragged_const_op_test.py
index 66c3947..9c3b2ac 100644
--- a/tensorflow/python/ops/ragged/ragged_const_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_const_op_test.py
@@ -238,8 +238,8 @@
dict(
pylist=[1, 2, 3],
inner_shape=(1, 1),
- exception=ValueError,
- message='Too many elements provided.'),
+ exception=TypeError,
+ message='Expected Tensor\'s shape'),
dict(
pylist=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
inner_shape=(2, 2),
diff --git a/tensorflow/python/ops/ragged/ragged_conversion_ops.py b/tensorflow/python/ops/ragged/ragged_conversion_ops.py
index 0385be0..83212e4 100644
--- a/tensorflow/python/ops/ragged/ragged_conversion_ops.py
+++ b/tensorflow/python/ops/ragged/ragged_conversion_ops.py
@@ -361,9 +361,14 @@
st_input = sparse_tensor.convert_to_tensor_or_sparse_tensor(
st_input, name='rt_input')
- if (st_input.dense_shape.shape.ndims != 2 and
- st_input.indices.shape.ndims is None or
- st_input.indices.shape.dims[1].value != 2):
+ static_rank_from_dense_shape = (
+ None if st_input.dense_shape.shape.ndims is None
+ else st_input.dense_shape.shape.dims[0].value)
+ static_rank_from_indices = (
+ None if st_input.indices.shape.ndims is None
+ else st_input.indices.shape.dims[1].value)
+
+ if static_rank_from_dense_shape != 2 and static_rank_from_indices != 2:
raise ValueError('rank(st_input) must be 2')
with ops.control_dependencies(
diff --git a/tensorflow/python/ops/ragged/ragged_from_sparse_op_test.py b/tensorflow/python/ops/ragged/ragged_from_sparse_op_test.py
index ff19dde..77418ff 100644
--- a/tensorflow/python/ops/ragged/ragged_from_sparse_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_from_sparse_op_test.py
@@ -64,6 +64,20 @@
self.assertRaisesRegexp(ValueError, r'rank\(st_input\) must be 2',
ragged.from_sparse, st3)
+ def testGoodPartialSparseTensorRank(self):
+ st1 = sparse_tensor.SparseTensor(
+ indices=[[0, 0]],
+ values=[0],
+ dense_shape=array_ops.placeholder(dtypes.int64))
+ st2 = sparse_tensor.SparseTensor(
+ indices=array_ops.placeholder(dtypes.int64),
+ values=[0],
+ dense_shape=[4, 3])
+
+ # Shouldn't throw ValueError
+ ragged.from_sparse(st1)
+ ragged.from_sparse(st2)
+
def testNonRaggedSparseTensor(self):
# "index_suffix" means the value of the innermost dimension of the index
# (i.e., indices[i][-1]).
diff --git a/tensorflow/python/ops/ragged/ragged_segment_op_test.py b/tensorflow/python/ops/ragged/ragged_segment_op_test.py
index 373a332..40a101b 100644
--- a/tensorflow/python/ops/ragged/ragged_segment_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_segment_op_test.py
@@ -118,8 +118,7 @@
combiner)
segmented = segment_op(rt, segment_ids, num_segments)
- with self.test_session():
- self.assertListEqual(segmented.eval().tolist(), expected)
+ self.assertListEqual(self.evaluate(segmented).tolist(), expected)
@parameterized.parameters(
(ragged.segment_sum, sum, [0, 0, 1, 1, 2, 2]),
@@ -155,9 +154,8 @@
combiner)
segmented = segment_op(rt, segment_ids, num_segments)
- with self.test_session():
- self.assertNestedListAmostEqual(
- self.evaluate(segmented).tolist(), expected, places=5)
+ self.assertNestedListAmostEqual(
+ self.evaluate(segmented).tolist(), expected, places=5)
def testRaggedRankTwo(self):
rt = ragged.constant([
@@ -172,16 +170,14 @@
[], # row 1
[[411, 412], [321, 322], [331]] # row 2
] # pyformat: disable
- with self.test_session():
- self.assertEqual(segmented1.eval().tolist(), expected1)
+ self.assertEqual(self.evaluate(segmented1).tolist(), expected1)
segment_ids2 = [1, 2, 1, 1]
segmented2 = ragged.segment_sum(rt, segment_ids2, 3)
expected2 = [[],
[[111+411, 112+412, 113, 114], [121+321, 322], [331]],
[]] # pyformat: disable
- with self.test_session():
- self.assertEqual(segmented2.eval().tolist(), expected2)
+ self.assertEqual(self.evaluate(segmented2).tolist(), expected2)
def testRaggedSegmentIds(self):
rt = ragged.constant([
@@ -195,8 +191,7 @@
expected = [[],
[111+321, 112+322, 113, 114],
[121+331+411, 412]] # pyformat: disable
- with self.test_session():
- self.assertEqual(segmented.eval().tolist(), expected)
+ self.assertEqual(self.evaluate(segmented).tolist(), expected)
def testShapeMismatchError1(self):
dt = constant_op.constant([1, 2, 3, 4, 5, 6])
@@ -226,7 +221,7 @@
array_ops.placeholder_with_default(segment_ids.values, None),
array_ops.placeholder_with_default(segment_ids.row_splits, None))
segmented2 = ragged.segment_sum(rt, segment_ids2, 3)
- with self.test_session():
+ with self.cached_session():
self.assertRaisesRegexp(
errors.InvalidArgumentError,
'segment_ids.shape must be a prefix of data.shape.*', segmented2.eval)
diff --git a/tensorflow/python/ops/ragged/ragged_tensor_bounding_shape_op_test.py b/tensorflow/python/ops/ragged/ragged_tensor_bounding_shape_op_test.py
index a1c10af..befe30f 100644
--- a/tensorflow/python/ops/ragged/ragged_tensor_bounding_shape_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_tensor_bounding_shape_op_test.py
@@ -28,41 +28,39 @@
def testDocStringExample(self):
# This is the example from ragged.bounding_shape.__doc__.
rt = ragged.constant([[1, 2, 3, 4], [5], [], [6, 7, 8, 9], [10]])
- with self.test_session():
- self.assertEqual(ragged.bounding_shape(rt).eval().tolist(), [5, 4])
+ self.assertEqual(self.evaluate(ragged.bounding_shape(rt)).tolist(), [5, 4])
def test2DRaggedTensorWithOneRaggedDimension(self):
values = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
rt1 = ragged.from_row_splits(values, [0, 2, 5, 6, 6, 7])
rt2 = ragged.from_row_splits(values, [0, 7])
rt3 = ragged.from_row_splits(values, [0, 0, 7, 7])
- with self.test_session():
- self.assertEqual(ragged.bounding_shape(rt1).eval().tolist(), [5, 3])
- self.assertEqual(ragged.bounding_shape(rt2).eval().tolist(), [1, 7])
- self.assertEqual(ragged.bounding_shape(rt3).eval().tolist(), [3, 7])
+ self.assertEqual(self.evaluate(ragged.bounding_shape(rt1)).tolist(), [5, 3])
+ self.assertEqual(self.evaluate(ragged.bounding_shape(rt2)).tolist(), [1, 7])
+ self.assertEqual(self.evaluate(ragged.bounding_shape(rt3)).tolist(), [3, 7])
def test3DRaggedTensorWithOneRaggedDimension(self):
values = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]
rt1 = ragged.from_row_splits(values, [0, 2, 5, 6, 6, 7])
rt2 = ragged.from_row_splits(values, [0, 7])
rt3 = ragged.from_row_splits(values, [0, 0, 7, 7])
- with self.test_session():
- self.assertEqual(ragged.bounding_shape(rt1).eval().tolist(), [5, 3, 2])
- self.assertEqual(ragged.bounding_shape(rt2).eval().tolist(), [1, 7, 2])
- self.assertEqual(ragged.bounding_shape(rt3).eval().tolist(), [3, 7, 2])
+ self.assertEqual(
+ self.evaluate(ragged.bounding_shape(rt1)).tolist(), [5, 3, 2])
+ self.assertEqual(
+ self.evaluate(ragged.bounding_shape(rt2)).tolist(), [1, 7, 2])
+ self.assertEqual(
+ self.evaluate(ragged.bounding_shape(rt3)).tolist(), [3, 7, 2])
def testNonRaggedTensor(self):
dt = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
- with self.test_session():
- self.assertEqual(ragged.bounding_shape(dt).eval().tolist(), [4, 3])
+ self.assertEqual(self.evaluate(ragged.bounding_shape(dt)).tolist(), [4, 3])
def testExplicitAxisOptimizations(self):
rt = ragged.from_row_splits(b'a b c d e f g'.split(), [0, 2, 5, 6, 6, 7])
- with self.test_session():
- self.assertEqual(ragged.bounding_shape(rt, 0).eval().tolist(), 5)
- self.assertEqual(ragged.bounding_shape(rt, 1).eval().tolist(), 3)
- self.assertEqual(
- ragged.bounding_shape(rt, [1, 0]).eval().tolist(), [3, 5])
+ self.assertEqual(self.evaluate(ragged.bounding_shape(rt, 0)).tolist(), 5)
+ self.assertEqual(self.evaluate(ragged.bounding_shape(rt, 1)).tolist(), 3)
+ self.assertEqual(
+ self.evaluate(ragged.bounding_shape(rt, [1, 0])).tolist(), [3, 5])
if __name__ == '__main__':
diff --git a/tensorflow/python/ops/ragged/ragged_tensor_test.py b/tensorflow/python/ops/ragged/ragged_tensor_test.py
index f66ca10..fa681c0 100644
--- a/tensorflow/python/ops/ragged/ragged_tensor_test.py
+++ b/tensorflow/python/ops/ragged/ragged_tensor_test.py
@@ -118,9 +118,8 @@
# From section: "Component Tensors"
rt = ragged.from_row_splits(
values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
- with self.test_session():
- self.assertEqual(rt.tolist(),
- [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
+ self.assertEqual(
+ self.evaluate(rt).tolist(), [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
del rt
# From section: "Alternative Row-Partitioning Schemes"
@@ -132,9 +131,8 @@
rt4 = ragged.from_row_starts(values, row_starts=[0, 4, 4, 7, 8])
rt5 = ragged.from_row_limits(values, row_limits=[4, 4, 7, 8, 8])
for rt in (rt1, rt2, rt3, rt4, rt5):
- with self.test_session():
- self.assertEqual(rt.tolist(),
- [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
+ self.assertEqual(
+ self.evaluate(rt).tolist(), [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
del rt1, rt2, rt3, rt4, rt5
# From section: "Multiple Ragged Dimensions"
@@ -142,28 +140,27 @@
values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
outer_rt = ragged.from_row_splits(values=inner_rt, row_splits=[0, 3, 3, 5])
self.assertEqual(outer_rt.ragged_rank, 2)
- with self.test_session():
- self.assertEqual(outer_rt.tolist(),
- [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
+ self.assertEqual(
+ self.evaluate(outer_rt).tolist(),
+ [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
del inner_rt, outer_rt
# From section: "Multiple Ragged Dimensions"
rt = ragged.from_nested_row_splits(
inner_values=[3, 1, 4, 1, 5, 9, 2, 6],
nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8]))
- with self.test_session():
- self.assertEqual(rt.tolist(),
- [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
del rt
# From section: "Uniform Inner Dimensions"
rt = ragged.from_row_splits(
values=array_ops.ones([5, 3]), row_splits=[0, 2, 5])
- with self.test_session():
- self.assertEqual(
- rt.tolist(),
- [[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
- self.assertEqual(rt.shape.as_list(), [2, None, 3])
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
+ self.assertEqual(rt.shape.as_list(), [2, None, 3])
del rt
#=============================================================================
@@ -208,9 +205,9 @@
rt = ragged.RaggedTensor(
values=values, row_splits=row_splits, internal=True)
- with self.test_session():
- self.assertEqual(rt.tolist(),
- [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testRaggedTensorConstructionErrors(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@@ -262,11 +259,11 @@
self.assertIs(rt_values, values)
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
- with self.test_session():
- self.assertAllEqual(rt_value_rowids, value_rowids)
- self.assertEqual(rt_nrows.eval(), 5)
- self.assertEqual(rt.tolist(),
- [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
+ self.assertAllEqual(rt_value_rowids, value_rowids)
+ self.assertEqual(self.evaluate(rt_nrows), 5)
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromValueRowIdsWithDerivedNRowsDynamic(self):
# nrows is not known at graph creation time.
@@ -285,11 +282,11 @@
self.assertIs(rt_values, values)
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
- with self.test_session():
- self.assertAllEqual(rt_value_rowids, value_rowids)
- self.assertEqual(rt_nrows.eval(), 5)
- self.assertEqual(rt.tolist(),
- [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
+ self.assertAllEqual(rt_value_rowids, value_rowids)
+ self.assertEqual(self.evaluate(rt_nrows), 5)
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromValueRowIdsWithExplicitNRows(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@@ -308,10 +305,9 @@
self.assertIs(rt_values, values)
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
self.assertIs(rt_nrows, nrows) # cached_nrows
- with self.test_session():
- self.assertEqual(
- rt.tolist(),
- [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []])
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []])
def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@@ -330,11 +326,11 @@
self.assertIs(rt_values, values)
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
self.assertIs(rt_nrows, nrows) # cached_nrows
- with self.test_session():
- self.assertAllEqual(rt_value_rowids, value_rowids)
- self.assertAllEqual(rt_nrows, nrows)
- self.assertEqual(rt.tolist(),
- [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
+ self.assertAllEqual(rt_value_rowids, value_rowids)
+ self.assertAllEqual(rt_nrows, nrows)
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromValueRowIdsWithEmptyValues(self):
rt = ragged.from_value_rowids([], [])
@@ -344,9 +340,8 @@
self.assertEqual(rt.ragged_rank, 1)
self.assertEqual(rt.values.shape.as_list(), [0])
self.assertEqual(ragged.value_rowids(rt).shape.as_list(), [0])
- with self.test_session():
- self.assertEqual(rt_nrows.eval().tolist(), 0)
- self.assertEqual(rt.tolist(), [])
+ self.assertEqual(self.evaluate(rt_nrows).tolist(), 0)
+ self.assertEqual(self.evaluate(rt).tolist(), [])
def testFromRowSplits(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@@ -363,10 +358,10 @@
self.assertIs(rt_values, values)
self.assertIs(rt_row_splits, row_splits)
- with self.test_session():
- self.assertEqual(rt_nrows.eval(), 5)
- self.assertEqual(rt.tolist(),
- [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
+ self.assertEqual(self.evaluate(rt_nrows), 5)
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowSplitsWithEmptySplits(self):
err_msg = 'row_splits tensor may not be empty'
@@ -387,11 +382,11 @@
rt_nrows = ragged.nrows(rt)
self.assertIs(rt_values, values)
- with self.test_session():
- self.assertEqual(rt_nrows.eval(), 5)
- self.assertAllEqual(rt_row_starts, row_starts)
- self.assertEqual(rt.tolist(),
- [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
+ self.assertEqual(self.evaluate(rt_nrows), 5)
+ self.assertAllEqual(rt_row_starts, row_starts)
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowLimits(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@@ -407,11 +402,11 @@
rt_nrows = ragged.nrows(rt)
self.assertIs(rt_values, values)
- with self.test_session():
- self.assertEqual(rt_nrows.eval(), 5)
- self.assertAllEqual(rt_row_limits, row_limits)
- self.assertEqual(rt.tolist(),
- [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
+ self.assertEqual(self.evaluate(rt_nrows), 5)
+ self.assertAllEqual(rt_row_limits, row_limits)
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowLengths(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@@ -428,11 +423,11 @@
self.assertIs(rt_values, values)
self.assertIs(rt_row_lengths, row_lengths) # cached_nrows
- with self.test_session():
- self.assertEqual(rt_nrows.eval(), 5)
- self.assertAllEqual(rt_row_lengths, row_lengths)
- self.assertEqual(rt.tolist(),
- [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
+ self.assertEqual(self.evaluate(rt_nrows), 5)
+ self.assertAllEqual(rt_row_lengths, row_lengths)
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromNestedValueRowIdsWithDerivedNRows(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@@ -452,12 +447,11 @@
rt_values_value_rowids = ragged.value_rowids(rt_values)
self.assertIs(rt_values_values, values)
- with self.test_session():
- self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
- self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
- self.assertEqual(
- rt.tolist(),
- [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
+ self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
+ self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
def testFromNestedValueRowIdsWithExplicitNRows(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@@ -483,14 +477,14 @@
rt_values_nrows = ragged.nrows(rt_values)
self.assertIs(rt_values_values, values)
- with self.test_session():
- self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
- self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
- self.assertAllEqual(rt_nrows, nrows[0])
- self.assertAllEqual(rt_values_nrows, nrows[1])
- self.assertEqual(rt.tolist(),
- [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [],
- [[b'f'], [b'g'], []], [], []])
+ self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
+ self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
+ self.assertAllEqual(rt_nrows, nrows[0])
+ self.assertAllEqual(rt_values_nrows, nrows[1])
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g'], []], [],
+ []])
def testFromNestedValueRowIdsWithExplicitNRowsMismatch(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@@ -535,10 +529,9 @@
self.assertIs(rt_values_values, inner_values)
self.assertIs(rt_row_splits, nested_row_splits[0])
self.assertIs(rt_values_row_splits, nested_row_splits[1])
- with self.test_session():
- self.assertEqual(
- rt.tolist(),
- [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
def testFromNestedRowSplitsWithNonListInput(self):
with self.assertRaisesRegexp(TypeError,
@@ -603,24 +596,31 @@
rt2 = ragged.from_value_rowids(values, value_rowids)
for rt in [rt1, rt2]:
- with self.test_session():
- self.assertEqual(rt.tolist(),
- [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
- self.assertEqual(rt.values.eval().tolist(),
- [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
- self.assertEqual(rt.values.shape.dims[0].value, 7)
- self.assertEqual(
- ragged.value_rowids(rt).eval().tolist(), [0, 0, 2, 2, 2, 3, 4])
- self.assertEqual(ragged.nrows(rt).eval().tolist(), 5)
- self.assertEqual(rt.row_splits.eval().tolist(), [0, 2, 2, 5, 6, 7])
- self.assertEqual(ragged.row_starts(rt).eval().tolist(), [0, 2, 2, 5, 6])
- self.assertEqual(ragged.row_limits(rt).eval().tolist(), [2, 2, 5, 6, 7])
- self.assertEqual(
- ragged.row_lengths(rt).eval().tolist(), [2, 0, 3, 1, 1])
- self.assertEqual(rt.inner_values.eval().tolist(),
- [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
- self.assertEqual([s.eval().tolist() for s in rt.nested_row_splits],
- [[0, 2, 2, 5, 6, 7]])
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
+ self.assertEqual(
+ self.evaluate(rt.values).tolist(),
+ [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
+ self.assertEqual(rt.values.shape.dims[0].value, 7)
+ self.assertEqual(
+ self.evaluate(ragged.value_rowids(rt)).tolist(),
+ [0, 0, 2, 2, 2, 3, 4])
+ self.assertEqual(self.evaluate(ragged.nrows(rt)).tolist(), 5)
+ self.assertEqual(
+ self.evaluate(rt.row_splits).tolist(), [0, 2, 2, 5, 6, 7])
+ self.assertEqual(
+ self.evaluate(ragged.row_starts(rt)).tolist(), [0, 2, 2, 5, 6])
+ self.assertEqual(
+ self.evaluate(ragged.row_limits(rt)).tolist(), [2, 2, 5, 6, 7])
+ self.assertEqual(
+ self.evaluate(ragged.row_lengths(rt)).tolist(), [2, 0, 3, 1, 1])
+ self.assertEqual(
+ self.evaluate(rt.inner_values).tolist(),
+ [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
+ self.assertEqual(
+ [self.evaluate(s).tolist() for s in rt.nested_row_splits],
+ [[0, 2, 2, 5, 6, 7]])
def testRaggedTensorAccessors_3d_with_ragged_rank_1(self):
values = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]
@@ -630,27 +630,32 @@
rt2 = ragged.from_value_rowids(values, value_rowids)
for rt in [rt1, rt2]:
- with self.test_session():
- self.assertEqual(rt.tolist(),
- [[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]],
- [[10, 11]], [[12, 13]]])
- self.assertEqual(
- rt.values.eval().tolist(),
- [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
- self.assertEqual(rt.values.shape.dims[0].value, 7)
- self.assertEqual(
- ragged.value_rowids(rt).eval().tolist(), [0, 0, 2, 2, 2, 3, 4])
- self.assertEqual(ragged.nrows(rt).eval().tolist(), 5)
- self.assertEqual(rt.row_splits.eval().tolist(), [0, 2, 2, 5, 6, 7])
- self.assertEqual(ragged.row_starts(rt).eval().tolist(), [0, 2, 2, 5, 6])
- self.assertEqual(ragged.row_limits(rt).eval().tolist(), [2, 2, 5, 6, 7])
- self.assertEqual(
- ragged.row_lengths(rt).eval().tolist(), [2, 0, 3, 1, 1])
- self.assertEqual(
- rt.inner_values.eval().tolist(),
- [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
- self.assertEqual([s.eval().tolist() for s in rt.nested_row_splits],
- [[0, 2, 2, 5, 6, 7]])
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]], [[10, 11]],
+ [[12, 13]]])
+ self.assertEqual(
+ self.evaluate(rt.values).tolist(),
+ [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
+ self.assertEqual(rt.values.shape.dims[0].value, 7)
+ self.assertEqual(
+ self.evaluate(ragged.value_rowids(rt)).tolist(),
+ [0, 0, 2, 2, 2, 3, 4])
+ self.assertEqual(self.evaluate(ragged.nrows(rt)).tolist(), 5)
+ self.assertEqual(
+ self.evaluate(rt.row_splits).tolist(), [0, 2, 2, 5, 6, 7])
+ self.assertEqual(
+ self.evaluate(ragged.row_starts(rt)).tolist(), [0, 2, 2, 5, 6])
+ self.assertEqual(
+ self.evaluate(ragged.row_limits(rt)).tolist(), [2, 2, 5, 6, 7])
+ self.assertEqual(
+ self.evaluate(ragged.row_lengths(rt)).tolist(), [2, 0, 3, 1, 1])
+ self.assertEqual(
+ self.evaluate(rt.inner_values).tolist(),
+ [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
+ self.assertEqual(
+ [self.evaluate(s).tolist() for s in rt.nested_row_splits],
+ [[0, 2, 2, 5, 6, 7]])
def testRaggedTensorAccessors_3d_with_ragged_rank_2(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@@ -666,36 +671,39 @@
rt2 = ragged.from_nested_value_rowids(values, nested_value_rowids)
for rt in [rt1, rt2]:
- with self.test_session():
- self.assertEqual(
- rt.tolist(),
- [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
- self.assertEqual(rt.values.eval().tolist(),
- [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
- self.assertEqual(rt.values.shape.dims[0].value, 5)
- self.assertEqual(
- ragged.value_rowids(rt).eval().tolist(), [0, 0, 1, 3, 3])
- self.assertEqual(ragged.nrows(rt).eval().tolist(), 4)
- self.assertEqual(rt.row_splits.eval().tolist(), [0, 2, 3, 3, 5])
- self.assertEqual(ragged.row_starts(rt).eval().tolist(), [0, 2, 3, 3])
- self.assertEqual(ragged.row_limits(rt).eval().tolist(), [2, 3, 3, 5])
- self.assertEqual(ragged.row_lengths(rt).eval().tolist(), [2, 1, 0, 2])
- self.assertEqual(rt.inner_values.eval().tolist(),
- [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
- self.assertEqual([s.eval().tolist() for s in rt.nested_row_splits],
- [[0, 2, 3, 3, 5], [0, 2, 2, 5, 6, 7]])
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
+ self.assertEqual(
+ self.evaluate(rt.values).tolist(),
+ [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
+ self.assertEqual(rt.values.shape.dims[0].value, 5)
+ self.assertEqual(
+ self.evaluate(ragged.value_rowids(rt)).tolist(), [0, 0, 1, 3, 3])
+ self.assertEqual(self.evaluate(ragged.nrows(rt)).tolist(), 4)
+ self.assertEqual(self.evaluate(rt.row_splits).tolist(), [0, 2, 3, 3, 5])
+ self.assertEqual(
+ self.evaluate(ragged.row_starts(rt)).tolist(), [0, 2, 3, 3])
+ self.assertEqual(
+ self.evaluate(ragged.row_limits(rt)).tolist(), [2, 3, 3, 5])
+ self.assertEqual(
+ self.evaluate(ragged.row_lengths(rt)).tolist(), [2, 1, 0, 2])
+ self.assertEqual(
+ self.evaluate(rt.inner_values).tolist(),
+ [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
+ self.assertEqual(
+ [self.evaluate(s).tolist() for s in rt.nested_row_splits],
+ [[0, 2, 3, 3, 5], [0, 2, 2, 5, 6, 7]])
def testNRowsWithTensorInput(self):
dt = constant_op.constant([[1, 2, 3], [4, 5, 6]])
nrows = ragged.nrows(dt)
- with self.test_session():
- self.assertEqual(nrows.eval(), 2)
+ self.assertEqual(self.evaluate(nrows), 2)
def testRowLengthsWithTensorInput(self):
dt = constant_op.constant([[1, 2, 3], [4, 5, 6]])
row_lengths = ragged.row_lengths(dt)
- with self.test_session():
- self.assertEqual(row_lengths.eval().tolist(), [3, 3])
+ self.assertEqual(self.evaluate(row_lengths).tolist(), [3, 3])
#=============================================================================
# RaggedTensor.shape
@@ -748,29 +756,27 @@
expected: The expected value of rt.__getitem__(slice_spec), as a python
list; or an exception class.
"""
- with self.test_session():
- tensor_slice_spec1 = _make_tensor_slice_spec(slice_spec, True)
- tensor_slice_spec2 = _make_tensor_slice_spec(slice_spec, False)
- value1 = rt.__getitem__(slice_spec).eval()
- value2 = rt.__getitem__(tensor_slice_spec1).eval()
- value3 = rt.__getitem__(tensor_slice_spec2).eval()
- if hasattr(value1, 'tolist'):
- value1 = value1.tolist()
- if hasattr(value2, 'tolist'):
- value2 = value2.tolist()
- if hasattr(value3, 'tolist'):
- value3 = value3.tolist()
- self.assertEqual(value1, expected, 'slice_spec=%s' % (slice_spec,))
- self.assertEqual(value2, expected, 'slice_spec=%s' % (slice_spec,))
- self.assertEqual(value3, expected, 'slice_spec=%s' % (slice_spec,))
+ tensor_slice_spec1 = _make_tensor_slice_spec(slice_spec, True)
+ tensor_slice_spec2 = _make_tensor_slice_spec(slice_spec, False)
+ value1 = self.evaluate(rt.__getitem__(slice_spec))
+ value2 = self.evaluate(rt.__getitem__(tensor_slice_spec1))
+ value3 = self.evaluate(rt.__getitem__(tensor_slice_spec2))
+ if hasattr(value1, 'tolist'):
+ value1 = value1.tolist()
+ if hasattr(value2, 'tolist'):
+ value2 = value2.tolist()
+ if hasattr(value3, 'tolist'):
+ value3 = value3.tolist()
+ self.assertEqual(value1, expected, 'slice_spec=%s' % (slice_spec,))
+ self.assertEqual(value2, expected, 'slice_spec=%s' % (slice_spec,))
+ self.assertEqual(value3, expected, 'slice_spec=%s' % (slice_spec,))
def _TestGetItemException(self, rt, slice_spec, expected, message):
"""Helper function for testing RaggedTensor.__getitem__ exceptions."""
- with self.test_session():
- tensor_slice_spec1 = _make_tensor_slice_spec(slice_spec, True)
- self.assertRaisesRegexp(expected, message, rt.__getitem__, slice_spec)
- self.assertRaisesRegexp(expected, message, rt.__getitem__,
- tensor_slice_spec1)
+ tensor_slice_spec1 = _make_tensor_slice_spec(slice_spec, True)
+ self.assertRaisesRegexp(expected, message, rt.__getitem__, slice_spec)
+ self.assertRaisesRegexp(expected, message, rt.__getitem__,
+ tensor_slice_spec1)
@parameterized.parameters(
# Tests for rt[i]
@@ -842,8 +848,7 @@
rt = ragged.from_row_splits(EXAMPLE_RAGGED_TENSOR_2D_VALUES,
EXAMPLE_RAGGED_TENSOR_2D_SPLITS)
- with self.test_session():
- self.assertEqual(rt.tolist(), EXAMPLE_RAGGED_TENSOR_2D)
+ self.assertEqual(self.evaluate(rt).tolist(), EXAMPLE_RAGGED_TENSOR_2D)
self._TestGetItem(rt, slice_spec, expected)
# pylint: disable=invalid-slice-index
@@ -887,8 +892,7 @@
# if sys.version_info[0] == 3:
# message = 'must be str, not int'
- with self.test_session():
- self.assertEqual(rt.tolist(), EXAMPLE_RAGGED_TENSOR_2D)
+ self.assertEqual(self.evaluate(rt).tolist(), EXAMPLE_RAGGED_TENSOR_2D)
self._TestGetItemException(rt, slice_spec, expected, message)
@parameterized.parameters(
@@ -962,8 +966,7 @@
rt = ragged.from_nested_row_splits(
EXAMPLE_RAGGED_TENSOR_4D_VALUES,
[EXAMPLE_RAGGED_TENSOR_4D_SPLITS1, EXAMPLE_RAGGED_TENSOR_4D_SPLITS2])
- with self.test_session():
- self.assertEqual(rt.tolist(), EXAMPLE_RAGGED_TENSOR_4D)
+ self.assertEqual(self.evaluate(rt).tolist(), EXAMPLE_RAGGED_TENSOR_4D)
self._TestGetItem(rt, slice_spec, expected)
@parameterized.parameters(
@@ -985,8 +988,7 @@
rt = ragged.from_nested_row_splits(
EXAMPLE_RAGGED_TENSOR_4D_VALUES,
[EXAMPLE_RAGGED_TENSOR_4D_SPLITS1, EXAMPLE_RAGGED_TENSOR_4D_SPLITS2])
- with self.test_session():
- self.assertEqual(rt.tolist(), EXAMPLE_RAGGED_TENSOR_4D)
+ self.assertEqual(self.evaluate(rt).tolist(), EXAMPLE_RAGGED_TENSOR_4D)
self._TestGetItemException(rt, slice_spec, expected, message)
@parameterized.parameters(
@@ -1026,8 +1028,7 @@
EXAMPLE_RAGGED_TENSOR_2D_SPLITS, dtype=dtypes.int64)
splits = array_ops.placeholder_with_default(splits, None)
rt = ragged.from_row_splits(EXAMPLE_RAGGED_TENSOR_2D_VALUES, splits)
- with self.test_session():
- self.assertEqual(rt.tolist(), EXAMPLE_RAGGED_TENSOR_2D)
+ self.assertEqual(self.evaluate(rt).tolist(), EXAMPLE_RAGGED_TENSOR_2D)
self._TestGetItem(rt, slice_spec, expected)
@parameterized.parameters(
@@ -1047,43 +1048,43 @@
splits2 = [0, 2, 2, 3]
values = constant_op.constant([['a', 'b'], ['c', 'd'], ['e', 'f']])
rt = ragged.from_nested_row_splits(values, [splits1, splits2])
- with self.test_session():
- rt_newaxis0 = rt[array_ops.newaxis]
- rt_newaxis1 = rt[:, array_ops.newaxis]
- rt_newaxis2 = rt[:, :, array_ops.newaxis]
- rt_newaxis3 = rt[:, :, :, array_ops.newaxis]
- rt_newaxis4 = rt[:, :, :, :, array_ops.newaxis]
+ rt_newaxis0 = rt[array_ops.newaxis]
+ rt_newaxis1 = rt[:, array_ops.newaxis]
+ rt_newaxis2 = rt[:, :, array_ops.newaxis]
+ rt_newaxis3 = rt[:, :, :, array_ops.newaxis]
+ rt_newaxis4 = rt[:, :, :, :, array_ops.newaxis]
- self.assertEqual(rt.tolist(),
- [[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []])
- self.assertEqual(
- rt_newaxis0.tolist(),
- [[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []]])
- self.assertEqual(
- rt_newaxis1.tolist(),
- [[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]]], [[]]])
- self.assertEqual(
- rt_newaxis2.tolist(),
- [[[[[b'a', b'b'], [b'c', b'd']]], [[]], [[[b'e', b'f']]]], []])
- self.assertEqual(
- rt_newaxis3.tolist(),
- [[[[[b'a', b'b']], [[b'c', b'd']]], [], [[[b'e', b'f']]]], []])
- self.assertEqual(
- rt_newaxis4.tolist(),
- [[[[[b'a'], [b'b']], [[b'c'], [b'd']]], [], [[[b'e'], [b'f']]]], []])
+ self.assertEqual(
+ self.evaluate(rt).tolist(),
+ [[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []])
+ self.assertEqual(
+ self.evaluate(rt_newaxis0).tolist(),
+ [[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []]])
+ self.assertEqual(
+ self.evaluate(rt_newaxis1).tolist(),
+ [[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]]], [[]]])
+ self.assertEqual(
+ self.evaluate(rt_newaxis2).tolist(),
+ [[[[[b'a', b'b'], [b'c', b'd']]], [[]], [[[b'e', b'f']]]], []])
+ self.assertEqual(
+ self.evaluate(rt_newaxis3).tolist(),
+ [[[[[b'a', b'b']], [[b'c', b'd']]], [], [[[b'e', b'f']]]], []])
+ self.assertEqual(
+ self.evaluate(rt_newaxis4).tolist(),
+ [[[[[b'a'], [b'b']], [[b'c'], [b'd']]], [], [[[b'e'], [b'f']]]], []])
- self.assertEqual(rt.ragged_rank, 2)
- self.assertEqual(rt_newaxis0.ragged_rank, 3)
- self.assertEqual(rt_newaxis1.ragged_rank, 3)
- self.assertEqual(rt_newaxis2.ragged_rank, 3)
- self.assertEqual(rt_newaxis3.ragged_rank, 2)
- self.assertEqual(rt_newaxis4.ragged_rank, 2)
+ self.assertEqual(rt.ragged_rank, 2)
+ self.assertEqual(rt_newaxis0.ragged_rank, 3)
+ self.assertEqual(rt_newaxis1.ragged_rank, 3)
+ self.assertEqual(rt_newaxis2.ragged_rank, 3)
+ self.assertEqual(rt_newaxis3.ragged_rank, 2)
+ self.assertEqual(rt_newaxis4.ragged_rank, 2)
- self.assertEqual(rt_newaxis0.shape.as_list(), [1, None, None, None, 2])
- self.assertEqual(rt_newaxis1.shape.as_list(), [2, None, None, None, 2])
- self.assertEqual(rt_newaxis2.shape.as_list(), [2, None, None, None, 2])
- self.assertEqual(rt_newaxis3.shape.as_list(), [2, None, None, 1, 2])
- self.assertEqual(rt_newaxis4.shape.as_list(), [2, None, None, 2, 1])
+ self.assertEqual(rt_newaxis0.shape.as_list(), [1, None, None, None, 2])
+ self.assertEqual(rt_newaxis1.shape.as_list(), [2, None, None, None, 2])
+ self.assertEqual(rt_newaxis2.shape.as_list(), [2, None, None, None, 2])
+ self.assertEqual(rt_newaxis3.shape.as_list(), [2, None, None, 1, 2])
+ self.assertEqual(rt_newaxis4.shape.as_list(), [2, None, None, 2, 1])
#=============================================================================
# RaggedTensor.__str__
@@ -1133,13 +1134,15 @@
rt2_times_10 = rt2.with_inner_values(rt2.inner_values * 10)
rt1_expanded = rt1.with_values(array_ops.expand_dims(rt1.values, axis=1))
- with self.test_session():
- self.assertEqual(rt1_plus_10.tolist(),
- [[11, 12], [13, 14, 15], [16], [], [17]])
- self.assertEqual(rt2_times_10.tolist(),
- [[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]])
- self.assertEqual(rt1_expanded.tolist(),
- [[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]])
+ self.assertEqual(
+ self.evaluate(rt1_plus_10).tolist(),
+ [[11, 12], [13, 14, 15], [16], [], [17]])
+ self.assertEqual(
+ self.evaluate(rt2_times_10).tolist(),
+ [[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]])
+ self.assertEqual(
+ self.evaluate(rt1_expanded).tolist(),
+ [[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]])
#=============================================================================
# Session.run
diff --git a/tensorflow/python/ops/ragged/ragged_tile_op_test.py b/tensorflow/python/ops/ragged/ragged_tile_op_test.py
index bf62d96..672d212 100644
--- a/tensorflow/python/ops/ragged/ragged_tile_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_tile_op_test.py
@@ -170,6 +170,15 @@
rt_input=[[[[1], [2]], [[3]]], [[]], [[[4, 5]]]],
multiples=[1, 1, 1, 0],
expected=[[[[], []], [[]]], [[]], [[[]]]]),
+ #=========================================================================
+ # multiple=1
+ #=========================================================================
+ dict(
+ descr='rank=4, multiples=1 (no repeats)',
+ rt_input=[[[[1], [2]], [[3], [4]]], [[[5], [6]]]],
+ multiples=[1, 1, 1, 1],
+ expected=[[[[1], [2]], [[3], [4]]],
+ [[[5], [6]]]]),
]) # pyformat: disable
def testRaggedTile(self,
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index c893ef0..f2df87c 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -138,7 +138,9 @@
return rnd
-@tf_export("random.truncated_normal", "truncated_normal")
+@tf_export("random.truncated_normal",
+ v1=["random.truncated_normal", "truncated_normal"])
+@deprecation.deprecated_endpoints("truncated_normal")
def truncated_normal(shape,
mean=0.0,
stddev=1.0,
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index c20f8fb..5c74dff 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -805,16 +805,6 @@
return ResourceVariable(
variable_def=variable_def, import_scope=import_scope)
- @staticmethod
- def _OverloadAllOperators(): # pylint: disable=invalid-name
- """Register overloads for all operators."""
- for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
- ResourceVariable._OverloadOperator(operator)
- # For slicing, bind getitem differently than a tensor (use SliceHelperVar
- # instead)
- # pylint: disable=protected-access
- setattr(ResourceVariable, "__getitem__", array_ops._SliceHelperVar)
-
def _AsTensor(self):
return self.value()
@@ -826,30 +816,6 @@
"""Unsupported."""
raise NotImplementedError("ResourceVariable does not implement set_shape()")
- @staticmethod
- def _OverloadOperator(operator): # pylint: disable=invalid-name
- """Defer an operator overload to `ops.Tensor`.
-
- We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
-
- Args:
- operator: string. The operator name.
- """
-
- tensor_oper = getattr(ops.Tensor, operator)
- def _run_op(a, *args):
- # pylint: disable=protected-access
- value = a._AsTensor()
- return tensor_oper(value, *args)
-
- # Propagate __doc__ to wrapper
- try:
- _run_op.__doc__ = tensor_oper.__doc__
- except AttributeError:
- pass
-
- setattr(ResourceVariable, operator, _run_op)
-
__array_priority__ = 100
def is_initialized(self, name=None):
@@ -1435,7 +1401,6 @@
variables.Variable, variables.Variable._TensorConversionFunction) # pylint: disable=protected-access
# pylint: disable=protected-access
-ResourceVariable._OverloadAllOperators()
ops.register_dense_tensor_like_type(ResourceVariable)
@@ -1525,3 +1490,6 @@
new_variable._maybe_initialize_checkpointable()
# pylint: enable=protected-access
return new_variable
+
+ops.NotDifferentiable("VarIsInitializedOp")
+ops.NotDifferentiable("VariableShape")
diff --git a/tensorflow/python/ops/resources.py b/tensorflow/python/ops/resources.py
index db67406..86477c9 100644
--- a/tensorflow/python/ops/resources.py
+++ b/tensorflow/python/ops/resources.py
@@ -21,6 +21,7 @@
from __future__ import print_function
import collections
+import os
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -86,7 +87,9 @@
resource_list = shared_resources() + local_resources()
with ops.name_scope(name):
# Run all operations on CPU
- with ops.device("/cpu:0"):
+ local_device = os.environ.get(
+ "TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0")
+ with ops.device(local_device):
if not resource_list:
# Return an empty tensor so we only need to check for returned tensor
# size being 0 as an indication of model ready.
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 57ecb50..ec48cab 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -117,7 +117,7 @@
inferred_dtypes = [element.dtype for element in nest.flatten(state)]
if not inferred_dtypes:
raise ValueError("Unable to infer dtype from empty state.")
- all_same = all([x == inferred_dtypes[0] for x in inferred_dtypes])
+ all_same = all(x == inferred_dtypes[0] for x in inferred_dtypes)
if not all_same:
raise ValueError(
"State has tensors of different inferred_dtypes. Unable to infer a "
@@ -348,7 +348,10 @@
return results
-@tf_export("nn.bidirectional_dynamic_rnn")
+@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
+ "keras.layers.RNN(cell))`, which is equivalent to "
+ "this API")
+@tf_export(v1=["nn.bidirectional_dynamic_rnn"])
def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
initial_state_fw=None, initial_state_bw=None,
dtype=None, parallel_iterations=None,
@@ -1490,7 +1493,10 @@
return (outputs, state)
-@tf_export("nn.static_bidirectional_rnn")
+@deprecation.deprecated(None, "Please use `keras.layers.Bidirectional("
+ "keras.layers.RNN(cell, unroll=True))`, which is "
+ "equivalent to this API")
+@tf_export(v1=["nn.static_bidirectional_rnn"])
def static_bidirectional_rnn(cell_fw,
cell_bw,
inputs,
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 050b486..ffc4561 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -36,6 +36,7 @@
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import activations
from tensorflow.python.keras import initializers
+from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
@@ -410,7 +411,7 @@
"performance on GPU.", self)
# Inputs must be 2-dimensional.
- self.input_spec = base_layer.InputSpec(ndim=2)
+ self.input_spec = input_spec.InputSpec(ndim=2)
self._num_units = num_units
if activation:
@@ -507,7 +508,7 @@
"Please use tf.contrib.cudnn_rnn.CudnnGRU for better "
"performance on GPU.", self)
# Inputs must be 2-dimensional.
- self.input_spec = base_layer.InputSpec(ndim=2)
+ self.input_spec = input_spec.InputSpec(ndim=2)
self._num_units = num_units
if activation:
@@ -683,7 +684,7 @@
"performance on GPU.", self)
# Inputs must be 2-dimensional.
- self.input_spec = base_layer.InputSpec(ndim=2)
+ self.input_spec = input_spec.InputSpec(ndim=2)
self._num_units = num_units
self._forget_bias = forget_bias
@@ -871,7 +872,7 @@
"performance on GPU.", self)
# Inputs must be 2-dimensional.
- self.input_spec = base_layer.InputSpec(ndim=2)
+ self.input_spec = input_spec.InputSpec(ndim=2)
self._num_units = num_units
self._use_peepholes = use_peepholes
@@ -1394,7 +1395,7 @@
return self._cell(inputs, state, scope=scope)
-@tf_export("nn.rnn_cell.MultiRNNCell")
+@tf_export(v1=["nn.rnn_cell.MultiRNNCell"])
class MultiRNNCell(RNNCell):
"""RNN cell composed sequentially of multiple simple cells.
@@ -1407,6 +1408,9 @@
```
"""
+ @deprecated(None, "This class is equivalent as "
+ "tf.keras.layers.StackedRNNCells, and will be replaced by "
+ "that in Tensorflow 2.0.")
def __init__(self, cells, state_is_tuple=True):
"""Create a RNN cell composed sequentially of a number of RNNCells.
@@ -1452,7 +1456,7 @@
if self._state_is_tuple:
return tuple(cell.state_size for cell in self._cells)
else:
- return sum([cell.state_size for cell in self._cells])
+ return sum(cell.state_size for cell in self._cells)
@property
def output_size(self):
diff --git a/tensorflow/python/ops/sort_ops.py b/tensorflow/python/ops/sort_ops.py
new file mode 100644
index 0000000..c3e23d7
--- /dev/null
+++ b/tensorflow/python/ops/sort_ops.py
@@ -0,0 +1,197 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Support for sorting tensors.
+
+@@argsort
+@@sort
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops as framework_ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export('sort')
+def sort(values, axis=-1, direction='ASCENDING', name=None):
+ """Sorts a tensor.
+
+ Args:
+ values: 1-D or higher numeric `Tensor`.
+ axis: The axis along which to sort. The default is -1, which sorts the last
+ axis.
+ direction: The direction in which to sort the values (`'ASCENDING'` or
+ `'DESCENDING'`).
+ name: Optional name for the operation.
+
+ Returns:
+ A `Tensor` with the same dtype and shape as `values`, with the elements
+ sorted along the given `axis`.
+
+ Raises:
+ ValueError: If axis is not a constant scalar, or the direction is invalid.
+ """
+ with framework_ops.name_scope(name, 'sort'):
+ return _sort_or_argsort(values, axis, direction, return_argsort=False)
+
+
+@tf_export('argsort')
+def argsort(values, axis=-1, direction='ASCENDING', stable=False, name=None):
+ """Returns the indices of a tensor that give its sorted order along an axis.
+
+ For a 1D tensor, `tf.gather(values, tf.argsort(values))` is equivalent to
+ `tf.sort(values)`. For higher dimensions, the output has the same shape as
+ `values`, but along the given axis, values represent the index of the sorted
+ element in that slice of the tensor at the given position.
+
+ Args:
+ values: 1-D or higher numeric `Tensor`.
+ axis: The axis along which to sort. The default is -1, which sorts the last
+ axis.
+ direction: The direction in which to sort the values (`'ASCENDING'` or
+ `'DESCENDING'`).
+ stable: If True, equal elements in the original tensor will not be
+ re-ordered in the returned order. Unstable sort is not yet implemented,
+ but will eventually be the default for performance reasons. If you require
+ a stable order, pass `stable=True` for forwards compatibility.
+ name: Optional name for the operation.
+
+ Returns:
+ An int32 `Tensor` with the same shape as `values`. The indices that would
+ sort each slice of the given `values` along the given `axis`.
+
+ Raises:
+ ValueError: If axis is not a constant scalar, or the direction is invalid.
+ """
+ del stable # Unused.
+ with framework_ops.name_scope(name, 'argsort'):
+ return _sort_or_argsort(values, axis, direction, return_argsort=True)
+
+
+def _sort_or_argsort(values, axis, direction, return_argsort):
+ """Internal sort/argsort implementation.
+
+ Args:
+ values: The input values.
+ axis: The axis along which to sort.
+ direction: 'ASCENDING' or 'DESCENDING'.
+ return_argsort: Whether to return the argsort result.
+
+ Returns:
+ Either the sorted values, or the indices of the sorted values in the
+ original tensor. See the `sort` and `argsort` docstrings.
+
+ Raises:
+ ValueError: If axis is not a constant scalar, or the direction is invalid.
+ """
+ if direction not in _SORT_IMPL:
+ raise ValueError('%s should be one of %s' % (direction, ', '.join(
+ sorted(_SORT_IMPL.keys()))))
+ # Axis must be an integer, not a Tensor.
+ axis = framework_ops.convert_to_tensor(axis, name='axis')
+ axis_static = tensor_util.constant_value(axis)
+ if axis.shape.ndims != 0 or axis_static is None:
+ raise ValueError('axis must be a constant scalar')
+ axis_static = int(axis_static) # Avoids NumPy casting error
+
+ values = framework_ops.convert_to_tensor(values, name='values')
+
+ return _SORT_IMPL[direction](values, axis_static, return_argsort)
+
+
+def _descending_sort(values, axis, return_argsort=False):
+ """Sorts values in reverse using `top_k`.
+
+ Args:
+ values: Tensor of numeric values.
+ axis: Index of the axis which values should be sorted along.
+ return_argsort: If False, return the sorted values. If True, return the
+ indices that would sort the values.
+
+ Returns:
+ The sorted values.
+ """
+ k = array_ops.shape(values)[axis]
+ rank = array_ops.rank(values)
+ static_rank = values.shape.ndims
+ # Fast path: sorting the last axis.
+ if axis == -1 or axis + 1 == values.get_shape().ndims:
+ top_k_input = values
+ transposition = None
+ else:
+ # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`.
+ if axis < 0:
+ # Calculate the actual axis index if counting from the end. Use the static
+ # rank if available, or else make the axis back into a tensor.
+ axis += static_rank or rank
+ if static_rank is not None:
+ # Prefer to calculate the transposition array in NumPy and make it a
+ # constant.
+ transposition = constant_op.constant(
+ np.r_[
+ # Axes up to axis are unchanged.
+ np.arange(axis),
+ # Swap axis and rank - 1.
+ [static_rank - 1],
+ # Axes in [axis + 1, rank - 1) are unchanged.
+ np.arange(axis + 1, static_rank - 1),
+ # Swap axis and rank - 1.
+ [axis]],
+ name='transposition')
+ else:
+ # Generate the transposition array from the tensors.
+ transposition = array_ops.concat(
+ [
+ # Axes up to axis are unchanged.
+ math_ops.range(axis),
+ # Swap axis and rank - 1.
+ [rank - 1],
+ # Axes in [axis + 1, rank - 1) are unchanged.
+ math_ops.range(axis + 1, rank - 1),
+ # Swap axis and rank - 1.
+ [axis]
+ ],
+ axis=0)
+ top_k_input = array_ops.transpose(values, transposition)
+
+ values, indices = nn_ops.top_k(top_k_input, k)
+ return_value = indices if return_argsort else values
+ if transposition is not None:
+ # transposition contains a single cycle of length 2 (swapping 2 elements),
+ # so it is an involution (it is its own inverse).
+ return_value = array_ops.transpose(return_value, transposition)
+ return return_value
+
+
+def _ascending_sort(values, axis, return_argsort=False):
+ # Negate the values to get the ascending order from descending sort.
+ values_or_indices = _descending_sort(-values, axis, return_argsort)
+ # If not argsort, negate the values again.
+ return values_or_indices if return_argsort else -values_or_indices
+
+
+_SORT_IMPL = {
+ 'ASCENDING': _ascending_sort,
+ 'DESCENDING': _descending_sort,
+}
diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/python/ops/sort_ops_test.py
similarity index 96%
rename from tensorflow/contrib/framework/python/ops/sort_ops_test.py
rename to tensorflow/python/ops/sort_ops_test.py
index 791b32c..8a92f49 100644
--- a/tensorflow/contrib/framework/python/ops/sort_ops_test.py
+++ b/tensorflow/python/ops/sort_ops_test.py
@@ -20,7 +20,6 @@
import numpy as np
-from tensorflow.contrib.framework.python.ops import sort_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -28,6 +27,7 @@
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import sort_ops
from tensorflow.python.platform import test
@@ -88,9 +88,7 @@
self.assertAllEqual(
np.sort(arr, axis=0)[::-1],
sort_ops.sort(
- constant_op.constant(arr),
- axis=0,
- direction='DESCENDING').eval())
+ constant_op.constant(arr), axis=0, direction='DESCENDING').eval())
def testSort_staticallyKnownRank_constantTransposition(self):
# The transposition array should be a constant if the rank of "values" is
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py
index 1223b29..2ca9c0c 100644
--- a/tensorflow/python/ops/sparse_grad.py
+++ b/tensorflow/python/ops/sparse_grad.py
@@ -195,7 +195,7 @@
parts_a = array_ops.gather(grad, rows if not adj_a else cols)
parts_b = array_ops.gather(b if not adj_b else array_ops.transpose(b),
cols if not adj_a else rows)
- a_values_grad = math_ops.reduce_sum(parts_a * parts_b, reduction_indices=1)
+ a_values_grad = math_ops.reduce_sum(parts_a * parts_b, axis=1)
# gradients w.r.t. (a_indices, a_values, a_shape, b)
return (None, a_values_grad, None, b_grad)
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 58cd829..91baa6f 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -939,7 +939,7 @@
output_shape)
-@tf_export("sparse_to_dense")
+@tf_export(v1=["sparse_to_dense"])
@deprecation.deprecated(
None,
"Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.")
@@ -1849,8 +1849,7 @@
dense_shape=sp_input.dense_shape), empty_row_indicator)
-@tf_export(
- "io.serialize_sparse", v1=["io.serialize_sparse", "serialize_sparse"])
+@tf_export(v1=["io.serialize_sparse", "serialize_sparse"])
@deprecation.deprecated_endpoints("serialize_sparse")
def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object.
@@ -1867,6 +1866,25 @@
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
"""
+ return serialize_sparse_v2(sp_input, out_type, name)
+
+
+@tf_export("io.serialize_sparse", v1=[])
+def serialize_sparse_v2(sp_input, out_type=dtypes.string, name=None):
+ """Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object.
+
+ Args:
+ sp_input: The input `SparseTensor`.
+ out_type: The `dtype` to use for serialization.
+ name: A name prefix for the returned tensors (optional).
+
+ Returns:
+ A 3-vector (1-D `Tensor`), with each column representing the serialized
+ `SparseTensor`'s indices, values, and shape (respectively).
+
+ Raises:
+ TypeError: If `sp_input` is not a `SparseTensor`.
+ """
sp_input = _convert_to_sparse_tensor(sp_input)
return gen_sparse_ops.serialize_sparse(
@@ -1877,9 +1895,7 @@
out_type=out_type)
-@tf_export(
- "io.serialize_many_sparse",
- v1=["io.serialize_many_sparse", "serialize_many_sparse"])
+@tf_export(v1=["io.serialize_many_sparse", "serialize_many_sparse"])
@deprecation.deprecated_endpoints("serialize_many_sparse")
def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`.
@@ -1905,6 +1921,34 @@
Raises:
TypeError: If `sp_input` is not a `SparseTensor`.
"""
+ return serialize_many_sparse_v2(sp_input, out_type, name)
+
+
+@tf_export("io.serialize_many_sparse", v1=[])
+def serialize_many_sparse_v2(sp_input, out_type=dtypes.string, name=None):
+ """Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`.
+
+ The `SparseTensor` must have rank `R` greater than 1, and the first dimension
+ is treated as the minibatch dimension. Elements of the `SparseTensor`
+ must be sorted in increasing order of this first dimension. The serialized
+ `SparseTensor` objects going into each row of the output `Tensor` will have
+ rank `R-1`.
+
+ The minibatch size `N` is extracted from `sparse_shape[0]`.
+
+ Args:
+ sp_input: The input rank `R` `SparseTensor`.
+ out_type: The `dtype` to use for serialization.
+ name: A name prefix for the returned tensors (optional).
+
+ Returns:
+ A matrix (2-D `Tensor`) with `N` rows and `3` columns. Each column
+ represents serialized `SparseTensor`'s indices, values, and shape
+ (respectively).
+
+ Raises:
+ TypeError: If `sp_input` is not a `SparseTensor`.
+ """
sp_input = _convert_to_sparse_tensor(sp_input)
return gen_sparse_ops.serialize_many_sparse(
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index f44f694..21f4996 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -70,8 +70,7 @@
x = ops.convert_to_tensor(x, name='x')
# Note reduce_sum([]) = 0.
- log_prod_gamma_x = math_ops.reduce_sum(
- math_ops.lgamma(x), reduction_indices=[-1])
+ log_prod_gamma_x = math_ops.reduce_sum(math_ops.lgamma(x), axis=[-1])
# Note lgamma(0) = infinity, so if x = []
# log_gamma_sum_x = lgamma(0) = infinity, and
@@ -264,11 +263,11 @@
missing_indices = set(temp_axis_labels) - set(output_axis_labels)
if missing_indices:
- reduction_indices = [
+ axis = [
i for i, a in enumerate(temp_axis_labels)
if a not in output_axis_labels
]
- temp = math_ops.reduce_sum(temp, reduction_indices=reduction_indices)
+ temp = math_ops.reduce_sum(temp, axis=axis)
temp_axis_labels = ''.join(
a for a in temp_axis_labels if a in output_axis_labels)
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 03e491a..c614d07 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -72,6 +72,7 @@
from tensorflow.python.ops.random_ops import *
from tensorflow.python.ops.script_ops import py_func
from tensorflow.python.ops.session_ops import *
+from tensorflow.python.ops.sort_ops import *
from tensorflow.python.ops.sparse_ops import *
from tensorflow.python.ops.state_ops import assign
from tensorflow.python.ops.state_ops import assign_add
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index a20eec2..b2090da 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -311,7 +311,7 @@
return math_ops.range(array_ops.rank(x) - 1, -1, -1)
-@tf_export("strings.reduce_join", v1=["strings.reduce_join", "reduce_join"])
+@tf_export(v1=["strings.reduce_join", "reduce_join"])
@deprecation.deprecated_endpoints("reduce_join")
def reduce_join(inputs, axis=None, # pylint: disable=missing-docstring
keep_dims=False,
@@ -329,6 +329,17 @@
name=name)
+@tf_export("strings.reduce_join", v1=[])
+def reduce_join_v2( # pylint: disable=missing-docstring
+ inputs,
+ axis=None,
+ keepdims=False,
+ separator="",
+ name=None):
+ return reduce_join(
+ inputs, axis, keep_dims=keepdims, separator=separator, name=name)
+
+
reduce_join.__doc__ = deprecation.rewrite_argument_docstring(
gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis")
reduce_join.__doc__ = reduce_join.__doc__.replace("tf.reduce_join(",
@@ -357,11 +368,16 @@
substr_deprecated.__doc__ = gen_string_ops.substr.__doc__
-@tf_export("strings.substr")
+@tf_export(v1=["strings.substr"])
def substr(input, pos, len, name=None, unit="BYTE"):
return gen_string_ops.substr(input, pos, len, unit=unit, name=name)
+@tf_export("strings.substr", v1=[])
+def substr_v2(input, pos, len, unit="BYTE", name=None):
+ return substr(input, pos, len, name=name, unit=unit)
+
+
substr.__doc__ = gen_string_ops.substr.__doc__
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index a0ad43b..3f99b9f 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -58,7 +58,6 @@
_USER_NAME_PATTERNS = re.compile(r"^[a-z]([-a-z0-9]{0,29}[a-z0-9])?$", re.I)
-@tf_export("summary.should_record_summaries", v1=[])
def should_record_summaries():
"""Returns boolean Tensor which is true if summaries should be recorded."""
global _SHOULD_RECORD_SUMMARIES
@@ -67,9 +66,8 @@
return should() if callable(should) else should
-@tf_export("summary.record_summaries", v1=[])
@tf_contextlib.contextmanager
-def record_summaries(boolean=True):
+def _record_summaries(boolean=True):
"""Sets summary recording on or off per the provided boolean value.
The provided value can be a python boolean, a scalar boolean Tensor, or
@@ -105,17 +103,17 @@
should = lambda: math_ops.equal(global_step % n, 0)
if not context.executing_eagerly():
should = should()
- return record_summaries(should)
+ return _record_summaries(should)
def always_record_summaries():
"""Sets the should_record_summaries Tensor to always true."""
- return record_summaries(True)
+ return _record_summaries(True)
def never_record_summaries():
"""Sets the should_record_summaries Tensor to always false."""
- return record_summaries(False)
+ return _record_summaries(False)
@tf_export("summary.SummaryWriter", v1=[])
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 077bb64..ad81862 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -680,7 +680,7 @@
"Partitioner returned a partition list that does not match the "
"Variable's rank: %s vs. %s" % (partitions, shape))
- if any([p < 1 for p in partitions]):
+ if any(p < 1 for p in partitions):
raise ValueError(
"Partitioner returned zero partitions for some axes: %s" %
partitions)
@@ -799,15 +799,13 @@
vs.append(var)
# pylint: enable=protected-access
- # pylint: disable=protected-access
partitioned_var = variables.PartitionedVariable(name=name,
shape=shape,
dtype=dtype,
variable_list=vs,
partitions=partitions)
- # pylint: enable=protected-access
-
- self._partitioned_vars[name] = partitioned_var
+ if not context.executing_eagerly() or self._store_eager_variables:
+ self._partitioned_vars[name] = partitioned_var
return partitioned_var
def _get_single_variable(self,
@@ -909,20 +907,22 @@
variable_dtype = None
else:
# Instantiate initializer if provided initializer is a type object.
- if isinstance(initializer, type(init_ops.Initializer)):
+ if tf_inspect.isclass(initializer):
initializer = initializer(dtype=dtype)
- if shape and shape.is_fully_defined():
+ if shape is not None and shape.is_fully_defined():
init_val = lambda: initializer( # pylint: disable=g-long-lambda
shape.as_list(), dtype=dtype, partition_info=partition_info)
- elif not tf_inspect.getargspec(initializer).args:
+ variable_dtype = dtype.base_dtype
+ elif len(tf_inspect.getargspec(initializer).args) == len(
+ tf_inspect.getargspec(initializer).defaults or []):
init_val = initializer
+ variable_dtype = None
else:
- raise ValueError("You can only pass an initializer function that "
- "expects no arguments to its callable when the "
- "shape is not fully defined. The given initializer "
- "function expects the following args %s" %
- tf_inspect.getargspec(initializer).args)
- variable_dtype = dtype.base_dtype
+ raise ValueError("The initializer passed is not valid. It should "
+ "be a callable with no arguments and the "
+ "shape should not be provided or an instance of "
+ "`tf.keras.initializers.*' and `shape` should be "
+ "fully defined.")
# Create the variable.
if use_resource is None:
@@ -2233,8 +2233,8 @@
try:
return self._enter_scope_uncached()
- except:
- if not self._building_function:
+ except Exception:
+ if self._in_graph_mode and not self._building_function:
if self._graph_context_manager is not None:
self._graph_context_manager.__exit__(*sys.exc_info())
raise
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 5bee522..4824c92 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -18,7 +18,8 @@
from __future__ import print_function
import enum # pylint: disable=g-bad-import-order
-
+import functools
+import os
import six
from tensorflow.core.framework import attr_value_pb2
@@ -860,18 +861,18 @@
else:
return v.value()
- @staticmethod
- def _OverloadAllOperators(): # pylint: disable=invalid-name
+ @classmethod
+ def _OverloadAllOperators(cls): # pylint: disable=invalid-name
"""Register overloads for all operators."""
for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
- Variable._OverloadOperator(operator)
+ cls._OverloadOperator(operator)
# For slicing, bind getitem differently than a tensor (use SliceHelperVar
# instead)
# pylint: disable=protected-access
- setattr(Variable, "__getitem__", array_ops._SliceHelperVar)
+ setattr(cls, "__getitem__", array_ops._SliceHelperVar)
- @staticmethod
- def _OverloadOperator(operator): # pylint: disable=invalid-name
+ @classmethod
+ def _OverloadOperator(cls, operator): # pylint: disable=invalid-name
"""Defer an operator overload to `ops.Tensor`.
We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
@@ -879,17 +880,26 @@
Args:
operator: string. The operator name.
"""
+ tensor_oper = getattr(ops.Tensor, operator)
- def _run_op(a, *args):
+ def _run_op(a, *args, **kwargs):
# pylint: disable=protected-access
- return getattr(ops.Tensor, operator)(a._AsTensor(), *args)
- # Propagate __doc__ to wrapper
- try:
- _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__
- except AttributeError:
- pass
+ return tensor_oper(a._AsTensor(), *args, **kwargs)
- setattr(Variable, operator, _run_op)
+ functools.update_wrapper(_run_op, tensor_oper)
+ setattr(cls, operator, _run_op)
+
+ def __iter__(self):
+ """Dummy method to prevent iteration. Do not call.
+
+ NOTE(mrry): If we register __getitem__ as an overloaded operator,
+ Python will valiantly attempt to iterate over the variable's Tensor from 0
+ to infinity. Declaring this method prevents this unintended behavior.
+
+ Raises:
+ TypeError: when invoked.
+ """
+ raise TypeError("'Variable' object is not iterable.")
# NOTE(mrry): This enables the Variable's overloaded "right" binary
# operators to run when the left operand is an ndarray, because it
@@ -1045,27 +1055,6 @@
else:
return None
- def __iadd__(self, other):
- raise NotImplementedError
-
- def __isub__(self, other):
- raise NotImplementedError
-
- def __imul__(self, other):
- raise NotImplementedError
-
- def __idiv__(self, other):
- raise NotImplementedError
-
- def __itruediv__(self, other):
- raise NotImplementedError
-
- def __irealdiv__(self, other):
- raise NotImplementedError
-
- def __ipow__(self, other):
- raise NotImplementedError
-
@tf_export(v1=["Variable"])
class VariableV1(Variable):
@@ -1576,18 +1565,6 @@
"""
return self._snapshot
- def __iter__(self):
- """Dummy method to prevent iteration. Do not call.
-
- NOTE(mrry): If we register __getitem__ as an overloaded operator,
- Python will valiantly attempt to iterate over the variable's Tensor from 0
- to infinity. Declaring this method prevents this unintended behavior.
-
- Raises:
- TypeError: when invoked.
- """
- raise TypeError("'Variable' object is not iterable.")
-
def value(self):
"""Returns the last snapshot of this variable.
@@ -2123,37 +2100,6 @@
else:
return v.value()
- @staticmethod
- def _OverloadAllOperators(): # pylint: disable=invalid-name
- """Register overloads for all operators."""
- for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
- Variable._OverloadOperator(operator) # pylint: disable=protected-access
- # For slicing, bind getitem differently than a tensor (use SliceHelperVar
- # instead)
- # pylint: disable=protected-access
- setattr(Variable, "__getitem__", array_ops._SliceHelperVar)
-
- @staticmethod
- def _OverloadOperator(operator): # pylint: disable=invalid-name
- """Defer an operator overload to `ops.Tensor`.
-
- We pull the operator out of ops.Tensor dynamically to avoid ordering issues.
-
- Args:
- operator: string. The operator name.
- """
-
- def _run_op(a, *args):
- # pylint: disable=protected-access
- return getattr(ops.Tensor, operator)(a._AsTensor(), *args)
- # Propagate __doc__ to wrapper
- try:
- _run_op.__doc__ = getattr(ops.Tensor, operator).__doc__
- except AttributeError:
- pass
-
- setattr(Variable, operator, _run_op)
-
def _gather_saveables_for_checkpoint(self):
"""For implementing `Checkpointable`. This object is saveable on its own."""
return {checkpointable.VARIABLE_VALUE_KEY: self}
@@ -2482,21 +2428,21 @@
"variable_list is not a list or tuple: %s" % variable_list)
if not isinstance(partitions, (list, tuple)):
raise TypeError("partitions is not a list or tuple: %s" % partitions)
- if not all([p >= 1 for p in partitions]):
+ if not all(p >= 1 for p in partitions):
raise ValueError("partition values must be positive: %s" % partitions)
if not variable_list:
raise ValueError("variable_list may not be empty")
# pylint: disable=protected-access
for v in variable_list:
# Sort the variable_list lexicographically according to var offset value.
- if not all([v._get_save_slice_info() is not None for v in variable_list]):
+ if not all(v._get_save_slice_info() is not None for v in variable_list):
raise ValueError(
"All variables must have a save_slice_info available: %s"
% [v.name for v in variable_list])
if len(shape) != len(partitions):
raise ValueError("len(shape) != len(partitions): %s vs. %s"
% (shape, partitions))
- if not all([v._get_save_slice_info().full_shape == shape]):
+ if v._get_save_slice_info().full_shape != shape:
raise ValueError(
"All variables' full shapes must match shape: %s; "
"but full shapes were: %s"
@@ -2523,7 +2469,7 @@
return len(self._variable_list)
def _partition_axes(self):
- if all([p == 1 for p in self._partitions]):
+ if all(p == 1 for p in self._partitions):
return [0]
else:
return [i for i, p in enumerate(self._partitions) if p > 1]
@@ -2963,7 +2909,9 @@
# Run all operations on CPU
if var_list:
init_vars = [state_ops.is_variable_initialized(v) for v in var_list]
- with ops.device("/cpu:0"):
+ local_device = os.environ.get(
+ "TF_DEVICE_FOR_UNINITIALIZED_VARIABLE_REPORTING", "/cpu:0")
+ with ops.device(local_device):
if not var_list:
# Return an empty tensor so we only need to check for returned tensor
# size being 0 as an indication of model ready.
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 5ab7bff..6821b63 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -509,7 +509,7 @@
# TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there
# is a tf.StopGradient in the loop body.
- assert all([g is not None for g in grad_outs])
+ assert all(g is not None for g in grad_outs)
counter = args[0]
total_iters = args[1]
return [counter + 1, total_iters] + grad_outs
@@ -839,6 +839,10 @@
# TODO(b/118452219): add test coverage for this.
tensor = func_graph_module.maybe_captured(tensor)
+ if isinstance(tensor, ops.EagerTensor):
+ # Eager execution doesn't quite support legacy tensorarray
+ return False
+
return tensor.op.type in TENSOR_ARRAY_HANDLE_OPS
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
index 4f7abb3..d6773d7 100644
--- a/tensorflow/python/platform/benchmark.py
+++ b/tensorflow/python/platform/benchmark.py
@@ -30,6 +30,7 @@
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.util import test_log_pb2
from tensorflow.python.client import timeline
+from tensorflow.python.framework import ops
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
@@ -299,6 +300,18 @@
benchmark_values["extras"].update(unreported_extras)
return benchmark_values
+ def evaluate(self, tensors):
+ """Evaluates tensors and returns numpy values.
+
+ Args:
+ tensors: A Tensor or a nested list/tuple of Tensors.
+
+ Returns:
+ tensors numpy values.
+ """
+ sess = ops.get_default_session() or self.cached_session()
+ return sess.run(tensors)
+
def _run_benchmarks(regex):
"""Run benchmarks that match regex `regex`.
diff --git a/tensorflow/python/platform/tf_logging.py b/tensorflow/python/platform/tf_logging.py
index 9f00abb..813bcb8 100644
--- a/tensorflow/python/platform/tf_logging.py
+++ b/tensorflow/python/platform/tf_logging.py
@@ -37,7 +37,7 @@
from tensorflow.python.util.tf_export import tf_export
-# Don't use this directly. Use _get_logger() instead.
+# Don't use this directly. Use get_logger() instead.
_logger = None
_logger_lock = threading.Lock()
@@ -78,7 +78,8 @@
return '(unknown file)', 0, '(unknown function)'
-def _get_logger():
+@tf_export('get_logger')
+def get_logger():
"""Return TF logger instance."""
global _logger
@@ -132,37 +133,37 @@
@tf_export(v1=['logging.log'])
def log(level, msg, *args, **kwargs):
- _get_logger().log(level, msg, *args, **kwargs)
+ get_logger().log(level, msg, *args, **kwargs)
@tf_export(v1=['logging.debug'])
def debug(msg, *args, **kwargs):
- _get_logger().debug(msg, *args, **kwargs)
+ get_logger().debug(msg, *args, **kwargs)
@tf_export(v1=['logging.error'])
def error(msg, *args, **kwargs):
- _get_logger().error(msg, *args, **kwargs)
+ get_logger().error(msg, *args, **kwargs)
@tf_export(v1=['logging.fatal'])
def fatal(msg, *args, **kwargs):
- _get_logger().fatal(msg, *args, **kwargs)
+ get_logger().fatal(msg, *args, **kwargs)
@tf_export(v1=['logging.info'])
def info(msg, *args, **kwargs):
- _get_logger().info(msg, *args, **kwargs)
+ get_logger().info(msg, *args, **kwargs)
@tf_export(v1=['logging.warn'])
def warn(msg, *args, **kwargs):
- _get_logger().warn(msg, *args, **kwargs)
+ get_logger().warn(msg, *args, **kwargs)
@tf_export(v1=['logging.warning'])
def warning(msg, *args, **kwargs):
- _get_logger().warning(msg, *args, **kwargs)
+ get_logger().warning(msg, *args, **kwargs)
_level_names = {
@@ -196,7 +197,7 @@
# Code below is taken from pyglib/logging
@tf_export(v1=['logging.vlog'])
def vlog(level, msg, *args, **kwargs):
- _get_logger().log(level, msg, *args, **kwargs)
+ get_logger().log(level, msg, *args, **kwargs)
def _GetNextLogCountPerToken(token):
@@ -299,13 +300,13 @@
@tf_export(v1=['logging.get_verbosity'])
def get_verbosity():
"""Return how much logging output will be produced."""
- return _get_logger().getEffectiveLevel()
+ return get_logger().getEffectiveLevel()
@tf_export(v1=['logging.set_verbosity'])
def set_verbosity(v):
"""Sets the threshold for what messages will be logged."""
- _get_logger().setLevel(v)
+ get_logger().setLevel(v)
def _get_thread_id():
diff --git a/tensorflow/python/profiler/profile_context_test.py b/tensorflow/python/profiler/profile_context_test.py
index abbeb8b..680cd71 100644
--- a/tensorflow/python/profiler/profile_context_test.py
+++ b/tensorflow/python/profiler/profile_context_test.py
@@ -51,7 +51,7 @@
self.evaluate(variables.global_variables_initializer())
total_steps = 101
for i in range(total_steps):
- sess.run(x)
+ self.evaluate(x)
if i == 14 or i == 49:
self.assertTrue(gfile.Exists(outfile))
gfile.Remove(outfile)
@@ -77,16 +77,16 @@
with session.Session() as sess:
self.evaluate(variables.global_variables_initializer())
for _ in range(10):
- sess.run(x)
+ self.evaluate(x)
for f in gfile.ListDirectory(test.get_temp_dir()):
# Warm up, no tracing.
self.assertFalse("run_meta" in f)
- sess.run(x)
+ self.evaluate(x)
self.assertTrue(
gfile.Exists(os.path.join(test.get_temp_dir(), "run_meta_11")))
gfile.Remove(os.path.join(test.get_temp_dir(), "run_meta_11"))
# fetched already.
- sess.run(x)
+ self.evaluate(x)
for f in gfile.ListDirectory(test.get_temp_dir()):
self.assertFalse("run_meta" in f)
@@ -98,7 +98,7 @@
with session.Session() as sess:
self.evaluate(variables.global_variables_initializer())
for _ in range(10):
- sess.run(x)
+ self.evaluate(x)
self.assertTrue(pctx.profiler is None)
self.assertTrue(
getattr(session.BaseSession, "profile_context", None) is None)
@@ -107,7 +107,7 @@
with session.Session() as sess:
self.evaluate(variables.global_variables_initializer())
for _ in range(10):
- sess.run(x)
+ self.evaluate(x)
self.assertFalse(pctx.profiler is None)
self.assertFalse(
getattr(session.BaseSession, "profile_context", None) is None)
diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py
index f696d48..1edc0c8 100644
--- a/tensorflow/python/saved_model/constants.py
+++ b/tensorflow/python/saved_model/constants.py
@@ -54,7 +54,7 @@
__name__, "MAIN_OP_KEY")
# CollectionDef key for the SavedModel train op.
-# Not exported while export_all_saved_models is in contrib.
+# Not exported while export_all_saved_models is experimental.
TRAIN_OP_KEY = "saved_model_train_op"
# Schema version for SavedModel.
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py
index 0b97a73..3678e50 100644
--- a/tensorflow/python/saved_model/loader_test.py
+++ b/tensorflow/python/saved_model/loader_test.py
@@ -104,9 +104,9 @@
with self.session(graph=graph) as sess:
# Check that x and y are not initialized
with self.assertRaises(errors.FailedPreconditionError):
- sess.run(x)
+ self.evaluate(x)
with self.assertRaises(errors.FailedPreconditionError):
- sess.run(y)
+ self.evaluate(y)
def test_load_with_import_scope(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 02c8dc7..d52251e 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -27,6 +27,7 @@
from tensorflow.python.eager import function
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
@@ -42,17 +43,50 @@
from tensorflow.python.util.tf_export import tf_export
+def _check_for_functional_keras_model(root):
+ """Makes an export signature for `root` if it's a functional Keras Model."""
+ # If nothing is decorated yet but this is a functional Keras Model (duck
+ # typed), we'll try to make a signature ourselves.
+ try:
+ inputs = root.inputs
+ input_names = root.input_names
+ except AttributeError:
+ return None
+ input_signature = []
+ for input_tensor, input_name in zip(inputs, input_names):
+ input_signature.append(tensor_spec.TensorSpec(
+ shape=input_tensor.shape, dtype=input_tensor.dtype,
+ name=input_name))
+
+ @def_function.function(input_signature=input_signature)
+ def _wrapped_model(*args):
+ outputs_list = nest.flatten(root(inputs=list(args)))
+ return {name: output for name, output
+ in zip(root.output_names, outputs_list)}
+ return _wrapped_model
+
+
def _find_function_to_export(root):
"""Iterate over `root`'s attributes, finding traced functions."""
- functions = []
- function_attribute_names = []
+ exported_function = None
+ previous_attribute_name = None
for attribute_name in dir(root):
attribute_value = getattr(root, attribute_name, None)
if isinstance(attribute_value, def_function.PolymorphicFunction):
- functions.append(attribute_value)
- function_attribute_names.append(attribute_name)
- # TODO(allenl): Automatically infer signatures for Keras functional models?
- if not functions:
+ if exported_function is not None:
+ raise ValueError(
+ ("Exporting an object with no "
+ "tf.saved_model.save(..., signatures=...) "
+ "argument specified, and with more than one "
+ "@tf.function-decorated method attached to it: {}. The signature "
+ "keys for these functions are ambiguous. Specify signature "
+ "functions explicitly.").format(
+ [previous_attribute_name, attribute_name]))
+ exported_function = attribute_value
+ previous_attribute_name = attribute_name
+ if exported_function is None:
+ exported_function = _check_for_functional_keras_model(root)
+ if exported_function is None:
raise ValueError(
("Exporting an object with no tf.saved_model.save(..., signatures=...) "
"argument specified, and with no @tf.function-decorated methods "
@@ -61,14 +95,7 @@
"signatures does not make sense, as the only consumers will expect "
"signatures. Either decorate a method or specify a signature function "
"explicitly."))
- elif len(functions) > 1:
- raise ValueError(
- ("Exporting an object with no tf.saved_model.save(..., signatures=...) "
- "argument specified, and with more than one @tf.function-decorated "
- "method attached to it: {}. The signature keys for these functions "
- "are ambiguous. Specify signature functions explicitly.").format(
- function_attribute_names))
- return functions[0]
+ return exported_function
def _canonicalize_signatures(signatures):
@@ -451,6 +478,19 @@
tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
```
+ `tf.keras.Model` instances constructed from inputs and outputs already have a
+ signature and so do not require a `@tf.function` decorator or a `signatures`
+ argument. If neither are specified, the model's forward pass is exported.
+
+ ```python
+ x = input_layer.Input((4,), name="x")
+ y = core.Dense(5, name="out")(x)
+ model = training.Model(x, y)
+ tf.saved_model.save(model, '/tmp/saved_model/')
+ # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
+ # with shape [None, 5]
+ ```
+
Variables must be tracked by assigning them to an attribute of a tracked
object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
from `tf.keras.layers`, optimizers from `tf.train`) track their variables
diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py
index 04cd9d0..8fb2803 100644
--- a/tensorflow/python/saved_model/save_test.py
+++ b/tensorflow/python/saved_model/save_test.py
@@ -21,6 +21,8 @@
import os
import sys
+import numpy
+
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
@@ -29,8 +31,11 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import input_layer
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.layers import merge
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import loader
@@ -214,6 +219,19 @@
with self.assertRaisesRegexp(ValueError, "call.*second_function"):
save.save(model, save_dir)
+ def test_subclassed_no_signature(self):
+
+ class Subclassed(training.Model):
+
+ def call(self, inputs):
+ return inputs * 2.
+
+ save_dir = os.path.join(self.get_temp_dir(), "saved_model")
+ model = Subclassed()
+ with self.assertRaisesRegexp(
+ ValueError, "no @tf.function-decorated methods"):
+ save.save(model, save_dir)
+
def test_docstring(self):
class Adder(util.Checkpoint):
@@ -254,6 +272,45 @@
self.assertNotIn("T", complex_node.attr)
self.assertNotIn("Tout", complex_node.attr)
+ def test_export_functional_keras_model(self):
+ x = input_layer.Input((4,), name="x")
+ y = core.Dense(4, name="out")(x)
+ model = training.Model(x, y)
+ save_dir = os.path.join(self.get_temp_dir(), "saved_model")
+ save.save(model, save_dir)
+ self.assertAllClose(
+ {"out": model(array_ops.ones([1, 4]))},
+ self._import_and_infer(save_dir, {"x": [[1., 1., 1., 1.]]}))
+
+ def test_export_functional_keras_model_after_fit(self):
+ x = input_layer.Input((1,))
+ y = core.Dense(1, name="y")(x)
+ model = training.Model(x, y)
+ model.compile(optimizer="sgd", loss="mse")
+ model.fit(x=numpy.array([[1.]]),
+ y=numpy.array([2.]), epochs=2)
+ save_dir = os.path.join(self.get_temp_dir(), "saved_model")
+ save.save(model, save_dir)
+ self.assertAllClose(
+ {"y": model(constant_op.constant([[1.], [2.]]))},
+ self._import_and_infer(save_dir, {"input_1": [[1.], [2.]]}))
+
+ def test_export_multi_input_functional_keras_model(self):
+ x1 = input_layer.Input((2,), name="x1")
+ x2 = input_layer.Input((2,), name="x2")
+ y1 = core.Dense(4)(merge.Add()([x1, x2]))
+ y2 = core.Dense(4)(merge.Multiply()([x1, x2]))
+ model = training.Model([x1, x2], [y1, y2])
+ save_dir = os.path.join(self.get_temp_dir(), "saved_model")
+ save.save(model, save_dir)
+ outputs = model([array_ops.ones([1, 2]), 2. * array_ops.ones([1, 2])])
+ self.assertAllClose(
+ {"dense": outputs[0], "dense_1": outputs[1]},
+ self._import_and_infer(
+ save_dir,
+ {"x1": [[1., 1.]],
+ "x2": [[2., 2.]]}))
+
class MemoryTests(test.TestCase):
diff --git a/tensorflow/python/saved_model/signature_constants.py b/tensorflow/python/saved_model/signature_constants.py
index 9646071..0efe176 100644
--- a/tensorflow/python/saved_model/signature_constants.py
+++ b/tensorflow/python/saved_model/signature_constants.py
@@ -135,7 +135,7 @@
################################################################################
# Train/Eval API constants.
-# Not exported while export_all_saved_models is in contrib.
+# Not exported while export_all_saved_models is experimental.
SUPERVISED_TRAIN_METHOD_NAME = "tensorflow/supervised/training"
diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl
index b41a1bc..3517c11 100644
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files.bzl
@@ -14,7 +14,6 @@
"errors/__init__.py",
"experimental/__init__.py",
"feature_column/__init__.py",
- "gfile/__init__.py",
"io/gfile/__init__.py",
"graph_util/__init__.py",
"image/__init__.py",
diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
index 0fadec0..e35b9c4 100644
--- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
@@ -17,6 +17,7 @@
"experimental/__init__.py",
"feature_column/__init__.py",
"gfile/__init__.py",
+ "io/gfile/__init__.py",
"graph_util/__init__.py",
"image/__init__.py",
"io/__init__.py",
diff --git a/tensorflow/python/tools/api/generator/doc_srcs.py b/tensorflow/python/tools/api/generator/doc_srcs.py
index 9e211d1..abb5886 100644
--- a/tensorflow/python/tools/api/generator/doc_srcs.py
+++ b/tensorflow/python/tools/api/generator/doc_srcs.py
@@ -37,7 +37,7 @@
'app': DocSource(docstring_module_name='platform.app'),
'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
'compat': DocSource(docstring_module_name='util.compat'),
- 'distribute': DocSource(docstring_module_name='training.distribute'),
+ 'distribute': DocSource(docstring_module_name='distribute.distribute_lib'),
'distributions': DocSource(
docstring_module_name='ops.distributions.distributions'),
'errors': DocSource(docstring_module_name='framework.errors'),
diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py
index 6504fbc..ea1f6aa 100644
--- a/tensorflow/python/tools/inspect_checkpoint.py
+++ b/tensorflow/python/tools/inspect_checkpoint.py
@@ -63,7 +63,7 @@
print("It's likely that your checkpoint file has been compressed "
"with SNAPPY.")
if ("Data loss" in str(e) and
- (any([e in file_name for e in [".index", ".meta", ".data"]]))):
+ any(e in file_name for e in [".index", ".meta", ".data"])):
proposed_file = ".".join(file_name.split(".")[0:-1])
v2_file_error_template = """
It's likely that this is a V2 checkpoint and you need to provide the filename
diff --git a/tensorflow/python/training/adadelta.py b/tensorflow/python/training/adadelta.py
index 95eca76..dd21016 100644
--- a/tensorflow/python/training/adadelta.py
+++ b/tensorflow/python/training/adadelta.py
@@ -25,7 +25,7 @@
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.AdadeltaOptimizer")
+@tf_export(v1=["train.AdadeltaOptimizer"])
class AdadeltaOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Adadelta algorithm.
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
index cc0da26..10c043b 100644
--- a/tensorflow/python/training/adagrad.py
+++ b/tensorflow/python/training/adagrad.py
@@ -28,7 +28,7 @@
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.AdagradOptimizer")
+@tf_export(v1=["train.AdagradOptimizer"])
class AdagradOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Adagrad algorithm.
diff --git a/tensorflow/python/training/adagrad_da.py b/tensorflow/python/training/adagrad_da.py
index 5ba4035..e23b713 100644
--- a/tensorflow/python/training/adagrad_da.py
+++ b/tensorflow/python/training/adagrad_da.py
@@ -26,7 +26,7 @@
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.AdagradDAOptimizer")
+@tf_export(v1=["train.AdagradDAOptimizer"])
class AdagradDAOptimizer(optimizer.Optimizer):
"""Adagrad Dual Averaging algorithm for sparse linear models.
diff --git a/tensorflow/python/training/adagrad_da_test.py b/tensorflow/python/training/adagrad_da_test.py
index 761f703..c7c4720 100644
--- a/tensorflow/python/training/adagrad_da_test.py
+++ b/tensorflow/python/training/adagrad_da_test.py
@@ -54,14 +54,14 @@
zip([grads0, grads1], [var0, var1]), global_step=global_step)
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose([0.0, 0.0], v0_val)
self.assertAllClose([0.0, 0.0], v1_val)
# Run a step of AdagradDA
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
# Let g to be gradient accumulator, gg to be gradient squared
# accumulator, T be the global step, lr is the learning rate, and k the
# initial gradient squared accumulator value.
@@ -119,14 +119,14 @@
zip([grads0, grads1], [var0, var1]), global_step=global_step)
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
# Run a step of AdagradDA
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType(
np.array([-0.904534, -1.603567]), v0_val)
self.assertAllCloseAccordingToType(
@@ -151,14 +151,14 @@
zip([grads0, grads1], [var0, var1]), global_step=global_step)
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
# Run a step of AdagradDA
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType(
np.array([-0.895489, -1.59555]), v0_val)
self.assertAllCloseAccordingToType(
@@ -183,14 +183,14 @@
zip([grads0, grads1], [var0, var1]), global_step=global_step)
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
# Run a step of AdagradDA
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType(
np.array([-0.046907, -0.093659]), v0_val)
self.assertAllCloseAccordingToType(
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index 704ad6d..0c701f4 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -29,7 +29,7 @@
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.AdamOptimizer")
+@tf_export(v1=["train.AdamOptimizer"])
class AdamOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Adam algorithm.
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index 13c9e9a..03810b5 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -22,7 +22,6 @@
import os.path
import shutil
import tempfile
-import threading
import time
from tensorflow.contrib.framework.python.framework import checkpoint_utils
@@ -52,6 +51,11 @@
from tensorflow.python.training import training_util
+# Provide a realistic start time for unit tests where we need to mock out
+# calls to time.time().
+MOCK_START_TIME = 1484695987.209386
+
+
class MockCheckpointSaverListener(
basic_session_run_hooks.CheckpointSaverListener):
@@ -95,7 +99,9 @@
with self.assertRaises(ValueError):
basic_session_run_hooks.SecondOrStepTimer()
- def test_every_secs(self):
+ @test.mock.patch.object(time, 'time')
+ def test_every_secs(self, mock_time):
+ mock_time.return_value = MOCK_START_TIME
timer = basic_session_run_hooks.SecondOrStepTimer(every_secs=1.0)
self.assertTrue(timer.should_trigger_for_step(1))
@@ -103,7 +109,7 @@
self.assertFalse(timer.should_trigger_for_step(1))
self.assertFalse(timer.should_trigger_for_step(2))
- time.sleep(1.0)
+ mock_time.return_value += 1.0
self.assertFalse(timer.should_trigger_for_step(1))
self.assertTrue(timer.should_trigger_for_step(2))
@@ -314,7 +320,7 @@
# in first run, elapsed time is None.
self.assertEqual(str(self.logged_message).find('sec'), -1)
- def _validate_print_every_n_secs(self, sess, at_end):
+ def _validate_print_every_n_secs(self, sess, at_end, mock_time):
t = constant_op.constant(42.0, name='foo')
train_op = constant_op.constant(3)
@@ -331,7 +337,7 @@
self.logged_message = ''
mon_sess.run(train_op)
self.assertEqual(str(self.logged_message).find(t.name), -1)
- time.sleep(1.0)
+ mock_time.return_value += 1.0
self.logged_message = ''
mon_sess.run(train_op)
@@ -345,17 +351,21 @@
# assertNotRegexpMatches is not supported by python 3.1 and later
self.assertEqual(str(self.logged_message).find(t.name), -1)
- def test_print_every_n_secs(self):
+ @test.mock.patch.object(time, 'time')
+ def test_print_every_n_secs(self, mock_time):
with ops.Graph().as_default(), session_lib.Session() as sess:
- self._validate_print_every_n_secs(sess, at_end=False)
+ mock_time.return_value = MOCK_START_TIME
+ self._validate_print_every_n_secs(sess, at_end=False, mock_time=mock_time)
# Verify proper reset.
- self._validate_print_every_n_secs(sess, at_end=False)
+ self._validate_print_every_n_secs(sess, at_end=False, mock_time=mock_time)
- def test_print_every_n_secs_and_end(self):
+ @test.mock.patch.object(time, 'time')
+ def test_print_every_n_secs_and_end(self, mock_time):
with ops.Graph().as_default(), session_lib.Session() as sess:
- self._validate_print_every_n_secs(sess, at_end=True)
+ mock_time.return_value = MOCK_START_TIME
+ self._validate_print_every_n_secs(sess, at_end=True, mock_time=mock_time)
# Verify proper reset.
- self._validate_print_every_n_secs(sess, at_end=True)
+ self._validate_print_every_n_secs(sess, at_end=True, mock_time=mock_time)
def test_print_formatter(self):
with ops.Graph().as_default(), session_lib.Session() as sess:
@@ -562,11 +572,8 @@
@test.mock.patch.object(time, 'time')
def test_save_secs_saves_periodically(self, mock_time):
- # Let's have a realistic start time
- current_time = 1484695987.209386
-
with self.graph.as_default():
- mock_time.return_value = current_time
+ mock_time.return_value = MOCK_START_TIME
hook = basic_session_run_hooks.CheckpointSaverHook(
self.model_dir, save_secs=2, scaffold=self.scaffold)
hook.begin()
@@ -576,10 +583,10 @@
sess.run(self.scaffold.init_op)
mon_sess = monitored_session._HookedSession(sess, [hook])
- mock_time.return_value = current_time
+ mock_time.return_value = MOCK_START_TIME
mon_sess.run(self.train_op) # Saved.
- mock_time.return_value = current_time + 0.5
+ mock_time.return_value = MOCK_START_TIME + 0.5
mon_sess.run(self.train_op) # Not saved.
self.assertEqual(1,
@@ -587,13 +594,13 @@
self.global_step.name))
# Simulate 2.5 seconds of sleep.
- mock_time.return_value = current_time + 2.5
+ mock_time.return_value = MOCK_START_TIME + 2.5
mon_sess.run(self.train_op) # Saved.
- mock_time.return_value = current_time + 2.6
+ mock_time.return_value = MOCK_START_TIME + 2.6
mon_sess.run(self.train_op) # Not saved.
- mock_time.return_value = current_time + 2.7
+ mock_time.return_value = MOCK_START_TIME + 2.7
mon_sess.run(self.train_op) # Not saved.
self.assertEqual(3,
@@ -601,7 +608,7 @@
self.global_step.name))
# Simulate 7.5 more seconds of sleep (10 seconds from start.
- mock_time.return_value = current_time + 10
+ mock_time.return_value = MOCK_START_TIME + 10
mon_sess.run(self.train_op) # Saved.
self.assertEqual(6,
checkpoint_utils.load_variable(self.model_dir,
@@ -609,11 +616,8 @@
@test.mock.patch.object(time, 'time')
def test_save_secs_calls_listeners_periodically(self, mock_time):
- # Let's have a realistic start time
- current_time = 1484695987.209386
-
with self.graph.as_default():
- mock_time.return_value = current_time
+ mock_time.return_value = MOCK_START_TIME
listener = MockCheckpointSaverListener()
hook = basic_session_run_hooks.CheckpointSaverHook(
self.model_dir,
@@ -626,28 +630,28 @@
sess.run(self.scaffold.init_op)
mon_sess = monitored_session._HookedSession(sess, [hook])
- mock_time.return_value = current_time + 0.5
+ mock_time.return_value = MOCK_START_TIME + 0.5
mon_sess.run(self.train_op) # hook runs here
- mock_time.return_value = current_time + 0.5
+ mock_time.return_value = MOCK_START_TIME + 0.5
mon_sess.run(self.train_op)
- mock_time.return_value = current_time + 3.0
+ mock_time.return_value = MOCK_START_TIME + 3.0
mon_sess.run(self.train_op) # hook runs here
- mock_time.return_value = current_time + 3.5
+ mock_time.return_value = MOCK_START_TIME + 3.5
mon_sess.run(self.train_op)
- mock_time.return_value = current_time + 4.0
+ mock_time.return_value = MOCK_START_TIME + 4.0
mon_sess.run(self.train_op)
- mock_time.return_value = current_time + 6.5
+ mock_time.return_value = MOCK_START_TIME + 6.5
mon_sess.run(self.train_op) # hook runs here
- mock_time.return_value = current_time + 7.0
+ mock_time.return_value = MOCK_START_TIME + 7.0
mon_sess.run(self.train_op) # hook won't run here, so it does at end
- mock_time.return_value = current_time + 7.5
+ mock_time.return_value = MOCK_START_TIME + 7.5
hook.end(sess) # hook runs here
self.assertEqual({
'begin': 1,
@@ -913,7 +917,9 @@
def tearDown(self):
shutil.rmtree(self.log_dir, ignore_errors=True)
- def test_step_counter_every_n_steps(self):
+ @test.mock.patch.object(time, 'time')
+ def test_step_counter_every_n_steps(self, mock_time):
+ mock_time.return_value = MOCK_START_TIME
with ops.Graph().as_default() as g, session_lib.Session() as sess:
variables.get_or_create_global_step()
train_op = training_util._increment_global_step(1)
@@ -925,7 +931,7 @@
mon_sess = monitored_session._HookedSession(sess, [hook])
with test.mock.patch.object(tf_logging, 'warning') as mock_log:
for _ in range(30):
- time.sleep(0.01)
+ mock_time.return_value += 0.01
mon_sess.run(train_op)
# logging.warning should not be called.
self.assertIsNone(mock_log.call_args)
@@ -941,7 +947,9 @@
self.assertEqual('global_step/sec', summary_value.tag)
self.assertGreater(summary_value.simple_value, 0)
- def test_step_counter_every_n_secs(self):
+ @test.mock.patch.object(time, 'time')
+ def test_step_counter_every_n_secs(self, mock_time):
+ mock_time.return_value = MOCK_START_TIME
with ops.Graph().as_default() as g, session_lib.Session() as sess:
variables.get_or_create_global_step()
train_op = training_util._increment_global_step(1)
@@ -953,9 +961,9 @@
self.evaluate(variables_lib.global_variables_initializer())
mon_sess = monitored_session._HookedSession(sess, [hook])
mon_sess.run(train_op)
- time.sleep(0.2)
+ mock_time.return_value += 0.2
mon_sess.run(train_op)
- time.sleep(0.2)
+ mock_time.return_value += 0.2
mon_sess.run(train_op)
hook.end(sess)
@@ -1037,13 +1045,15 @@
self.evaluate(variables_lib.global_variables_initializer())
self.mon_sess = monitored_session._HookedSession(sess, [self.hook])
- def test_steps_per_run_less_than_every_n_steps(self):
+ @test.mock.patch.object(time, 'time')
+ def test_steps_per_run_less_than_every_n_steps(self, mock_time):
+ mock_time.return_value = MOCK_START_TIME
with ops.Graph().as_default() as g, session_lib.Session() as sess:
self._setup_steps_per_run_test(10, 5, g, sess)
# Logs at 15, 25
for _ in range(5):
- time.sleep(0.01)
+ mock_time.return_value += 0.01
self.mon_sess.run(self.train_op)
self.hook.end(sess)
@@ -1058,13 +1068,15 @@
self.assertEqual('global_step/sec', summary_value.tag)
self.assertGreater(summary_value.simple_value, 0)
- def test_steps_per_run_equal_every_n_steps(self):
+ @test.mock.patch.object(time, 'time')
+ def test_steps_per_run_equal_every_n_steps(self, mock_time):
+ mock_time.return_value = MOCK_START_TIME
with ops.Graph().as_default() as g, session_lib.Session() as sess:
self._setup_steps_per_run_test(5, 5, g, sess)
# Logs at 10, 15, 20, 25
for _ in range(5):
- time.sleep(0.01)
+ mock_time.return_value += 0.01
self.mon_sess.run(self.train_op)
self.hook.end(sess)
@@ -1080,13 +1092,15 @@
self.assertEqual('global_step/sec', summary_value.tag)
self.assertGreater(summary_value.simple_value, 0)
- def test_steps_per_run_greater_than_every_n_steps(self):
+ @test.mock.patch.object(time, 'time')
+ def test_steps_per_run_greater_than_every_n_steps(self, mock_time):
+ mock_time.return_value = MOCK_START_TIME
with ops.Graph().as_default() as g, session_lib.Session() as sess:
self._setup_steps_per_run_test(5, 10, g, sess)
# Logs at 20, 30, 40, 50
for _ in range(5):
- time.sleep(0.01)
+ mock_time.return_value += 0.01
self.mon_sess.run(self.train_op)
self.hook.end(sess)
@@ -1199,7 +1213,9 @@
},
})
- def test_save_secs_saving_once_every_step(self):
+ @test.mock.patch.object(time, 'time')
+ def test_save_secs_saving_once_every_step(self, mock_time):
+ mock_time.return_value = MOCK_START_TIME
hook = basic_session_run_hooks.SummarySaverHook(
save_secs=0.5,
summary_writer=self.summary_writer,
@@ -1211,7 +1227,7 @@
mon_sess = monitored_session._HookedSession(sess, [hook])
for _ in range(4):
mon_sess.run(self.train_op)
- time.sleep(0.5)
+ mock_time.return_value += 0.5
hook.end(sess)
self.summary_writer.assert_summaries(
@@ -1279,27 +1295,43 @@
session_run_hook.SessionRunContext(
original_args=None, session=sess))
- def test_wait_for_step(self):
+ @test.mock.patch.object(time, 'sleep')
+ def test_wait_for_step(self, mock_sleep):
with ops.Graph().as_default():
gstep = variables.get_or_create_global_step()
hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000)
hook.begin()
+
with session_lib.Session() as sess:
+ # Mock out calls to time.sleep() to update the global step.
+
+ class Context(object):
+ counter = 0
+
+ def mock_sleep_side_effect(seconds):
+ del seconds # argument is ignored
+ Context.counter += 1
+ if Context.counter == 1:
+ # The first time sleep() is called, we update the global_step from
+ # 0 to 500.
+ sess.run(state_ops.assign(gstep, 500))
+ elif Context.counter == 2:
+ # The second time sleep() is called, we update the global_step from
+ # 500 to 1100.
+ sess.run(state_ops.assign(gstep, 1100))
+ else:
+ raise AssertionError(
+ 'Expected before_run() to terminate after the second call to '
+ 'time.sleep()')
+
+ mock_sleep.side_effect = mock_sleep_side_effect
+
+ # Run the mocked-out interaction with the hook.
self.evaluate(variables_lib.global_variables_initializer())
- waiter = threading.Thread(
- target=hook.before_run,
- args=(session_run_hook.SessionRunContext(
- original_args=None, session=sess),))
- waiter.daemon = True
- waiter.start()
- time.sleep(1.0)
- self.assertTrue(waiter.is_alive())
- sess.run(state_ops.assign(gstep, 500))
- time.sleep(1.0)
- self.assertTrue(waiter.is_alive())
- sess.run(state_ops.assign(gstep, 1100))
- time.sleep(1.2)
- self.assertFalse(waiter.is_alive())
+ run_context = session_run_hook.SessionRunContext(
+ original_args=None, session=sess)
+ hook.before_run(run_context)
+ self.assertEqual(Context.counter, 2)
class FinalOpsHookTest(test.TestCase):
@@ -1465,29 +1497,27 @@
@test.mock.patch.object(time, 'time')
def test_save_secs_saves_periodically(self, mock_time):
# Pick a fixed start time.
- current_time = 1484863632.
-
with self.graph.as_default():
- mock_time.return_value = current_time
+ mock_time.return_value = MOCK_START_TIME
hook = basic_session_run_hooks.ProfilerHook(
save_secs=2, output_dir=self.output_dir)
with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
sess.run(self.train_op) # Not saved.
self.assertEqual(0, self._count_timeline_files())
# Simulate 2.5 seconds of sleep.
- mock_time.return_value = current_time + 2.5
+ mock_time.return_value = MOCK_START_TIME + 2.5
sess.run(self.train_op) # Saved.
self.assertEqual(1, self._count_timeline_files())
# Pretend some small amount of time has passed.
- mock_time.return_value = current_time + 2.6
+ mock_time.return_value = MOCK_START_TIME + 2.6
sess.run(self.train_op) # Not saved.
# Edge test just before we should save the timeline.
- mock_time.return_value = current_time + 4.4
+ mock_time.return_value = MOCK_START_TIME + 4.4
sess.run(self.train_op) # Not saved.
self.assertEqual(1, self._count_timeline_files())
- mock_time.return_value = current_time + 4.5
+ mock_time.return_value = MOCK_START_TIME + 4.5
sess.run(self.train_op) # Saved.
self.assertEqual(2, self._count_timeline_files())
diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD
index d26932c..f97f42a 100644
--- a/tensorflow/python/training/checkpointable/BUILD
+++ b/tensorflow/python/training/checkpointable/BUILD
@@ -152,7 +152,7 @@
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
- "//tensorflow/python/eager:function",
+ "//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:test",
"//tensorflow/python/keras:engine",
"//tensorflow/python/keras:layers",
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index c29e5db..817552f 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -111,9 +111,6 @@
"""Base class for data structures which contain checkpointable objects."""
def __init__(self):
- # An append-only ordered set
- self._layers = []
-
self.trainable = True
self._extra_variables = []
@@ -128,22 +125,31 @@
("Only checkpointable objects (such as Layers or Optimizers) may be "
"stored in a List object. Got %s, which does not inherit from "
"CheckpointableBase.") % (value,))
- if (isinstance(value, CheckpointableDataStructure)
- or layer_utils.is_layer(value)
- or layer_utils.has_weights(value)):
- # Check for object-identity rather than with __eq__ to avoid
- # de-duplicating empty container types. Automatically generated list
- # wrappers keep things like "[] == []" true, which means "[] in [[]]" is
- # also true. This becomes not true once one of the lists is mutated.
- if not any((layer is value for layer in self._layers)):
- self._layers.append(value)
- if hasattr(value, "_use_resource_variables"):
- # In subclassed models, legacy layers (tf.layers) must always use
- # resource variables.
- value._use_resource_variables = True # pylint: disable=protected-access
+ if hasattr(value, "_use_resource_variables"):
+ # In subclassed models, legacy layers (tf.layers) must always use
+ # resource variables.
+ value._use_resource_variables = True # pylint: disable=protected-access
return value
@property
+ def _values(self):
+ """An iterable/sequence which may contain checkpointable objects."""
+ raise NotImplementedError("Abstract method")
+
+ @property
+ def _layers(self):
+ """All Layers and Layer containers, including empty containers."""
+ # Filter objects on demand so that wrapper objects use values from the thing
+ # they're wrapping if out of sync.
+ collected = []
+ for obj in self._values:
+ if (isinstance(obj, CheckpointableDataStructure)
+ or layer_utils.is_layer(obj)
+ or layer_utils.has_weights(obj)):
+ collected.append(obj)
+ return collected
+
+ @property
def layers(self):
return layer_utils.filter_empty_layer_containers(self._layers)
@@ -265,6 +271,10 @@
def _name_element(self, index):
return "%d" % (index,)
+ @property
+ def _values(self):
+ return self
+
def append(self, value):
"""Add a new checkpointable value."""
value = self._track_value(value, self._name_element(len(self._storage)))
@@ -479,6 +489,14 @@
def _make_storage(self, *args, **kwargs):
return dict(*args, **kwargs)
+ @property
+ def _values(self):
+ # Sort items deterministically by key
+ ordered = list(zip(*sorted(self.items(), key=lambda it: it[0])))
+ if ordered:
+ return ordered[1]
+ return []
+
def _name_element(self, key):
if not isinstance(key, six.string_types):
raise TypeError(
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index ff7d1f1..9cefd94 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -253,6 +253,13 @@
l.append(1)
self.assertEqual([1], l_wrapper)
+ def testLayerCollectionWithExternalMutation(self):
+ l = []
+ l_wrapper = data_structures._ListWrapper(l)
+ layer = core.Dense(1)
+ l.append(layer)
+ self.assertEqual([layer], l_wrapper.layers)
+
def testHashing(self):
has_sequences = set([data_structures.List(),
data_structures.List()])
@@ -324,6 +331,20 @@
with self.assertRaises(TypeError):
mapping[1] = data_structures.List()
+ def testLayerCollectionWithExternalMutation(self):
+ d = {}
+ root = tracking.Checkpointable()
+ root.wrapper = d
+ self.assertEqual([], root.wrapper.layers)
+ self.assertEqual([], root.wrapper.trainable_weights)
+ layer1 = core.Dense(1)
+ layer2 = core.Dense(1)
+ d["a"] = layer1
+ d["b"] = layer2
+ self.assertEqual([layer1, layer2], root.wrapper.layers)
+ # The layers have still not created variables
+ self.assertEqual([], root.wrapper.trainable_weights)
+
def testHashing(self):
has_mappings = set([data_structures.Mapping(),
data_structures.Mapping()])
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index f45f744..394cc33 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -31,6 +31,7 @@
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.ops import init_ops
@@ -549,13 +550,11 @@
return slot_variables
-def _serialize_checkpointables(
- checkpointable_objects, node_ids, object_names, slot_variables,
+def _add_attributes_to_object_graph(
+ checkpointable_objects, object_graph_proto, node_ids, object_names,
saveables_cache, object_map):
- """Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
- object_graph_proto = (
- checkpointable_object_graph_pb2.CheckpointableObjectGraph())
- named_saveables = []
+ """Create SaveableObjects and corresponding SerializedTensor protos."""
+ named_saveable_objects = []
if saveables_cache is None:
# No SaveableObject caching. Either we're executing eagerly, or building a
# static save which is specialized to the current Python state.
@@ -564,10 +563,9 @@
# If we are caching SaveableObjects, we need to build up a feed_dict with
# functions computing volatile Python state to be saved with the checkpoint.
feed_additions = {}
- for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
+ for checkpoint_id, (checkpointable, object_proto) in enumerate(
+ zip(checkpointable_objects, object_graph_proto.nodes)):
assert node_ids[checkpointable] == checkpoint_id
- object_proto = object_graph_proto.nodes.add()
- object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
object_name = object_names[checkpointable]
if object_map:
object_to_save = object_map.get(checkpointable, checkpointable)
@@ -645,14 +643,24 @@
"value.")
% (checkpointable, new_feed_key))
feed_additions.update(saveable_feed_dict)
- named_saveables.append(saveable)
+ named_saveable_objects.append(saveable)
+ return named_saveable_objects, feed_additions
+
+
+def _make_object_graph_proto(checkpointable_objects, node_ids, slot_variables):
+ """Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
+ object_graph_proto = (
+ checkpointable_object_graph_pb2.CheckpointableObjectGraph())
+ for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
+ assert node_ids[checkpointable] == checkpoint_id
+ object_proto = object_graph_proto.nodes.add()
+ object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access
child_proto = object_proto.children.add()
child_proto.node_id = node_ids[child.ref]
child_proto.local_name = child.name
-
- return named_saveables, object_graph_proto, feed_additions
+ return object_graph_proto
def _serialize_gathered_objects(
@@ -668,13 +676,18 @@
checkpointable_objects=checkpointable_objects,
node_ids=node_ids,
object_names=object_names)
- return _serialize_checkpointables(
+ object_graph_proto = _make_object_graph_proto(
checkpointable_objects=checkpointable_objects,
node_ids=node_ids,
+ slot_variables=slot_variables)
+ named_saveable_objects, feed_additions = _add_attributes_to_object_graph(
+ checkpointable_objects=checkpointable_objects,
+ object_graph_proto=object_graph_proto,
+ node_ids=node_ids,
object_names=object_names,
- slot_variables=slot_variables,
saveables_cache=saveables_cache,
object_map=object_map)
+ return named_saveable_objects, object_graph_proto, feed_additions
def _serialize_object_graph(root_checkpointable, saveables_cache):
@@ -716,6 +729,23 @@
return _serialize_object_graph(root_checkpointable, None)[0]
+def _find_objects(root_checkpointable):
+ """Find and number objects which are dependencies of `root_checkpointable`."""
+ checkpointable_objects, path_to_root = (
+ _breadth_first_checkpointable_traversal(root_checkpointable))
+ object_names = _ObjectIdentityDictionary()
+ for obj, path in path_to_root.items():
+ object_names[obj] = _object_prefix_from_path(path)
+ node_ids = _ObjectIdentityDictionary()
+ for node_id, node in enumerate(checkpointable_objects):
+ node_ids[node] = node_id
+ slot_variables = _serialize_slot_variables(
+ checkpointable_objects=checkpointable_objects,
+ node_ids=node_ids,
+ object_names=object_names)
+ return checkpointable_objects, node_ids, slot_variables
+
+
def list_objects(root_checkpointable):
"""Traverse the object graph and list all accessible objects.
@@ -730,23 +760,18 @@
Returns:
A flat list of objects.
"""
- # TODO(allenl): Extract out gathering logic so the naming logic doesn't have
- # to run.
- checkpointable_objects, path_to_root = (
- _breadth_first_checkpointable_traversal(root_checkpointable))
- object_names = _ObjectIdentityDictionary()
- for obj, path in path_to_root.items():
- object_names[obj] = _object_prefix_from_path(path)
- node_ids = _ObjectIdentityDictionary()
- for node_id, node in enumerate(checkpointable_objects):
- node_ids[node] = node_id
- _serialize_slot_variables(
- checkpointable_objects=checkpointable_objects,
- node_ids=node_ids,
- object_names=object_names)
+ checkpointable_objects, _, _ = _find_objects(root_checkpointable)
return checkpointable_objects
+def make_object_graph_without_attributes(root_checkpointable):
+ """Construct a CheckpointableObjectGraph proto with no variable values."""
+ checkpointable_objects, node_ids, slot_variables = _find_objects(
+ root_checkpointable)
+ return _make_object_graph_proto(
+ checkpointable_objects, node_ids, slot_variables)
+
+
def gather_initializers(root_checkpointable):
"""Traverse the object graph and find initialization ops.
@@ -1434,6 +1459,7 @@
elif session is None:
session = ops.get_default_session()
+ file_io.recursive_create_dir(os.path.dirname(file_prefix))
with ops.device("/cpu:0"):
save_path = saver.save(
sess=_SessionWithFeedDictAdditions(
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index 1995514..de9cac0 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -26,7 +26,7 @@
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
-from tensorflow.python.eager import function
+from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -44,6 +44,7 @@
from tensorflow.python.ops import variables
from tensorflow.python.training import adam
from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import momentum
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base
@@ -198,6 +199,17 @@
with self.assertRaises(NotImplementedError):
checkpoint_reversed.save(prefix)
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ def test_object_graph_no_attributes(self):
+ root = tracking.Checkpointable()
+ root.v = resource_variable_ops.ResourceVariable(1.)
+ root.opt = momentum.MomentumOptimizer(0.01, 0.5)
+ root.opt.minimize(root.v.read_value)
+ object_graph = checkpointable_utils.make_object_graph_without_attributes(
+ root)
+ # Four objects: Root, v, opt, and a slot variable for v
+ self.assertEqual(4, len(object_graph.nodes))
+
class _MirroringSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
@@ -632,7 +644,7 @@
checkpoint_directory)
status = root.restore(save_path=checkpoint_path)
def train_fn():
- @function.defun
+ @def_function.function
def _call_model(x):
return model(x)
with backprop.GradientTape() as tape:
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py
index 0ff97d8..b7e5c98 100644
--- a/tensorflow/python/training/coordinator.py
+++ b/tensorflow/python/training/coordinator.py
@@ -408,7 +408,7 @@
# Threads for the standard services.
-@tf_export("train.LooperThread")
+@tf_export(v1=["train.LooperThread"])
class LooperThread(threading.Thread):
"""A thread that runs code repeatedly, optionally on a timer.
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 4ef784d..ad27bc8 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -12,1621 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Library for running a computation across multiple devices."""
+"""Deprecated, please use ../distribute/distribute_lib.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import copy
-import threading
-import weakref
-import enum
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.distribute import reduce_util
-from tensorflow.python.eager import context as eager_context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops.losses import losses_impl
-from tensorflow.python.platform import tf_logging
-from tensorflow.python.training import device_util
-from tensorflow.python.training import distribution_strategy_context
-from tensorflow.python.util import nest
-from tensorflow.python.util.tf_export import tf_export
-from tensorflow.tools.docs import doc_controls
-
-
-# ------------------------------------------------------------------------------
-# Context tracking whether in a strategy.update() or .update_non_slot() call.
-
-
-_update_device = threading.local()
-
-
-def get_update_device():
- """Get the current device if in a `tf.distribute.Strategy.update()` call."""
- try:
- return _update_device.current
- except AttributeError:
- return None
-
-
-class UpdateContext(object):
- """Context manager when you are in `update()` or `update_non_slot()`."""
-
- def __init__(self, device):
- self._device = device
- self._old_device = None
-
- def __enter__(self):
- self._old_device = get_update_device()
- _update_device.current = self._device
-
- def __exit__(self, exception_type, exception_value, traceback):
- del exception_type, exception_value, traceback
- _update_device.current = self._old_device
-
-
-# ------------------------------------------------------------------------------
-# Public utility functions.
-
-
-@tf_export("distribute.get_loss_reduction")
-def get_loss_reduction():
- """`tf.distribute.ReduceOp` corresponding to the last loss reduction."""
- loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access
- if loss_reduction == losses_impl.Reduction.SUM:
- return reduce_util.ReduceOp.SUM
- return reduce_util.ReduceOp.MEAN
-
-
-# ------------------------------------------------------------------------------
-# Internal API for validating the current thread mode
-
-
-def _require_cross_replica_context_extended(extended):
- """Verify in cross-replica context."""
- context = _get_per_thread_mode()
- cross_replica = context.cross_replica_context
- if cross_replica is not None and cross_replica.extended is extended:
- return
- strategy = extended._container_strategy() # pylint: disable=protected-access
- # We have an error to report, figure out the right message.
- if context.distribution_strategy is not strategy:
- _wrong_strategy_scope(strategy, context)
- assert cross_replica is None
- raise RuntimeError("Method requires being in cross-replica context, use "
- "get_replica_context().merge_call()")
-
-
-def _wrong_strategy_scope(strategy, context):
- # Figure out the right error message.
- if not distribution_strategy_context.has_distribution_strategy():
- raise RuntimeError(
- 'Need to be inside "with strategy.scope()" for %s' %
- (strategy,))
- else:
- raise RuntimeError(
- "Mixing different tf.distribute.Strategy objects: %s is not %s" %
- (context.distribution_strategy, strategy))
-
-
-def require_replica_context(replica_ctx):
- """Verify in `replica_ctx` replica context."""
- context = _get_per_thread_mode()
- if context.replica_context is replica_ctx: return
- # We have an error to report, figure out the right message.
- if context.replica_context is None:
- raise RuntimeError("Need to be inside `call_for_each_replica()`")
- if context.distribution_strategy is replica_ctx.distribution_strategy:
- # Two different ReplicaContexts with the same tf.distribute.Strategy.
- raise RuntimeError("Mismatching ReplicaContext.")
- raise RuntimeError(
- "Mismatching tf.distribute.Strategy objects: %s is not %s." %
- (context.distribution_strategy, replica_ctx.distribution_strategy))
-
-
-def _require_distribution_strategy_scope_strategy(strategy):
- """Verify in a `strategy.scope()` in this thread."""
- context = _get_per_thread_mode()
- if context.distribution_strategy is strategy: return
- _wrong_strategy_scope(strategy, context)
-
-
-def _require_distribution_strategy_scope_extended(extended):
- """Verify in a `distribution_strategy.scope()` in this thread."""
- context = _get_per_thread_mode()
- if context.distribution_strategy.extended is extended: return
- # Report error.
- strategy = extended._container_strategy() # pylint: disable=protected-access
- _wrong_strategy_scope(strategy, context)
-
-
-# ------------------------------------------------------------------------------
-# Internal context managers used to implement the DistributionStrategy
-# base class
-
-
-class _CurrentDistributionContext(object):
- """Context manager setting the current `tf.distribute.Strategy`.
-
- Also: overrides the variable creator and optionally the current device.
- """
-
- def __init__(self,
- strategy,
- var_creator_scope,
- var_scope=None,
- default_device=None):
- self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access
- strategy)
- self._var_creator_scope = var_creator_scope
- self._var_scope = var_scope
- if default_device:
- self._device_scope = ops.device(default_device)
- else:
- self._device_scope = None
-
- def __enter__(self):
- _push_per_thread_mode(self._context)
- if self._var_scope:
- self._var_scope.__enter__()
- self._var_creator_scope.__enter__()
- if self._device_scope:
- self._device_scope.__enter__()
- return self._context.distribution_strategy
-
- def __exit__(self, exception_type, exception_value, traceback):
- if self._device_scope:
- self._device_scope.__exit__(exception_type, exception_value, traceback)
- self._var_creator_scope.__exit__(exception_type, exception_value, traceback)
- if self._var_scope:
- self._var_scope.__exit__(exception_type, exception_value, traceback)
- _pop_per_thread_mode()
-
-
-class _SameScopeAgainContext(object):
- """Trivial context manager when you are already in `scope()`."""
-
- def __init__(self, strategy):
- self._distribution_strategy = strategy
-
- def __enter__(self):
- return self._distribution_strategy
-
- def __exit__(self, exception_type, exception_value, traceback):
- del exception_type, exception_value, traceback
-
-
-# TODO(yuefengz): add more replication modes.
-@tf_export("distribute.InputReplicationMode")
-class InputReplicationMode(enum.Enum):
- """Replication mode for input function."""
-
- # The input function will be called on each worker independently, creating as
- # many input pipelines as number of workers. Replicas will dequeue from the
- # local Dataset on their worker. Distribution Strategy doesn't manage any
- # state sharing between such separate input pipelines.
- PER_WORKER = "PER_WORKER"
-
-
-@tf_export("distribute.InputContext")
-class InputContext(object):
- """A class wrapping information needed by an input function.
-
- This is a context class that is passed to the user's input fn and contains
- information about the compute replicas and input pipelines. The number of
- compute replicas (in sync training) helps compute per input pipeline batch
- size from the desired global batch size. Input pipeline information can be
- used to return a different subset of the input in each input pipeline (for
- e.g. shard the input pipeline, use a different input source etc).
- """
-
- def __init__(self,
- num_input_pipelines=1,
- input_pipeline_id=0,
- num_replicas_in_sync=1):
- """Initializes an InputContext object.
-
- Args:
- num_input_pipelines: the number of input pipelines in a cluster.
- input_pipeline_id: the current input pipeline id, should be an int in
- [0,`num_input_pipelines`).
- num_replicas_in_sync: the number of replicas that are in sync.
- """
- self._num_input_pipelines = num_input_pipelines
- self._input_pipeline_id = input_pipeline_id
- self._num_replicas_in_sync = num_replicas_in_sync
-
- @property
- def num_replicas_in_sync(self):
- """Returns the number of compute replicas in sync."""
- return self._num_replicas_in_sync
-
- @property
- def input_pipeline_id(self):
- """Returns the input pipeline ID."""
- return self._input_pipeline_id
-
- @property
- def num_input_pipelines(self):
- """Returns the number of input pipelines."""
- return self._num_input_pipelines
-
- def get_per_replica_batch_size(self, global_batch_size):
- """Returns the per-replica batch size.
-
- Args:
- global_batch_size: the global batch size which should be divisible by
- `num_replicas_in_sync`.
-
- Returns:
- the per-replica batch size.
-
- Raises:
- ValueError: if `global_batch_size` not divisible by
- `num_replicas_in_sync`.
- """
- if global_batch_size % self._num_replicas_in_sync != 0:
- raise ValueError("The `global_batch_size` %r is not divisible by "
- "`num_replicas_in_sync` %r " %
- (global_batch_size, self._num_replicas_in_sync))
- return global_batch_size // self._num_replicas_in_sync
-
-
-# ------------------------------------------------------------------------------
-# Base classes for all distribution strategies.
-
-
-@tf_export("distribute.Strategy")
-class DistributionStrategy(object):
- """A list of devices with a state & compute distribution policy.
-
- See [tensorflow/contrib/distribute/README.md](
- https://www.tensorflow.org/code/tensorflow/contrib/distribute/README.md)
- for overview and examples.
- """
-
- # TODO(josh11b): Raise an exception if variable partitioning requested before
- # we add support.
- # TODO(josh11b): Also `parameter_device_index` property?
- # TODO(josh11b): `map()`
- # TODO(josh11b): ClusterSpec/ClusterResolver
- # TODO(josh11b): Partitioned computations, state; sharding
- # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling
- # TODO(josh11b): List of replicas with their worker and parameter devices
- # (where the parameter devices may overlap in the ps case).
-
- def __init__(self, extended):
- self._extended = extended
-
- @property
- def extended(self):
- """`tf.distribute.StrategyExtended` with additional methods."""
- return self._extended
-
- def scope(self):
- """Returns a context manager selecting this Strategy as current.
-
- Inside a `with strategy.scope():` code block, this thread
- will use a variable creator set by `strategy`, and will
- enter its "cross-replica context".
-
- Returns:
- A context manager.
- """
- return self._extended._scope(self) # pylint: disable=protected-access
-
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def read_var(self, v):
- """DEPRECATED: use extended.read_var() instead."""
- return self._extended.read_var(v)
-
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def colocate_vars_with(self, colocate_with_variable):
- """DEPRECATED: use extended.colocate_vars_with() instead."""
- return self._extended.colocate_vars_with(colocate_with_variable)
-
- @doc_controls.do_not_generate_docs # DEPRECATED
- def distribute_dataset(self, dataset_fn):
- """Return a `dataset` split across all replicas. DEPRECATED.
-
- DEPRECATED: Please use `make_dataset_iterator` or
- `make_input_fn_iterator` instead.
-
- Suitable for providing input to `extended.call_for_each_replica()` by
- creating an iterator:
-
- ```
- def dataset_fn():
- return tf.data.Dataset.from_tensors([[1.]]).repeat()
-
- with strategy.scope():
- distributed_dataset = strategy.distribute_dataset(dataset_fn)
- iterator = distributed_dataset.make_initializable_iterator()
- replica_results = strategy.extended.call_for_each_replica(
- replica_fn, args=(iterator.get_next(),))
- ```
-
- Args:
- dataset_fn: A function that returns a `tf.data.Dataset`.
-
- Returns:
- A `PerReplicaDataset` that will produce data for each replica.
- """
- return self._extended._distribute_dataset(dataset_fn) # pylint: disable=protected-access
-
- def make_dataset_iterator(self, dataset):
- """Makes an iterator for input provided via input_dataset.
-
- Data from the given dataset will be distributed evenly across all the
- compute replicas. We will assume that the input dataset is batched by the
- global batch size. With this assumption, we will make a best effort to
- divide each batch across all the replicas (one or more workers).
- If this effort fails, an error will be thrown, and the user should instead
- use `make_input_fn_iterator` which provides more control to the user, and
- does not try to divide a batch across replicas.
-
- The user could also use `make_input_fn_iterator` if they want to
- customize which input is fed to which replica/worker etc.
-
- Args:
- dataset: `tf.data.Dataset` that will be distributed evenly across all
- replicas.
-
- Returns:
- An `tf.distribute.InputIterator` which returns inputs for each step of the
- computation. User should call `initialize` on the returned iterator.
- """
- return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access
-
- def make_input_fn_iterator(self,
- input_fn,
- replication_mode=InputReplicationMode.PER_WORKER):
- """Returns an iterator split across replicas created from an input function.
-
- The `input_fn` should take an `tf.distribute.InputContext` object where
- information about input sharding can be accessed:
-
- ```
- def input_fn(input_context):
- d = tf.data.Dataset.from_tensors([[1.]]).repeat()
- return d.shard(input_context.num_input_pipelines,
- input_context.input_pipeline_id)
- with strategy.scope():
- iterator = strategy.make_input_fn_iterator(
- input_fn)
- replica_results = strategy.extended.call_for_each_replica(
- replica_fn, iterator.get_next())
- ```
-
- Args:
- input_fn: A function that returns a `tf.data.Dataset`. This function is
- expected to take an `tf.distribute.InputContext` object.
- replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
- Only `PER_WORKER` is supported currently.
-
- Returns:
- An iterator object that can be initialized and fetched next element.
- """
- if replication_mode != InputReplicationMode.PER_WORKER:
- raise ValueError(
- "Input replication mode not supported: %r" % replication_mode)
- return self.extended._make_input_fn_iterator( # pylint: disable=protected-access
- input_fn, replication_mode=replication_mode)
-
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def broadcast(self, tensor, destinations=None):
- """DEPRECATED: use extended.broadcast_to() instead."""
- return self._extended.broadcast_to(tensor, destinations)
-
- @doc_controls.do_not_generate_docs # Use experimental_initialize() instead.
- def initialize(self):
- """DEPRECATED: Use `experimental_initialize()` instead."""
- return self._extended._initialize() # pylint: disable=protected-access
-
- def experimental_initialize(self):
- """Any initialization to be done before running any computations.
-
- In eager mode, it executes any initialization as a side effect.
- In graph mode, it creates the initialization ops and returns them.
-
- For example, TPU initialize_system ops.
-
- Returns:
- A list of ops to execute.
- """
- return self._extended._initialize() # pylint: disable=protected-access
-
- @doc_controls.do_not_generate_docs # Use experimental_finalize() instead.
- def finalize(self):
- """DEPRECATED: Use `experimental_finalize()` instead."""
- return self._extended._finalize() # pylint: disable=protected-access
-
- def experimental_finalize(self):
- """Any final actions to be done at the end of all computations.
-
- In eager mode, it executes any finalize actions as a side effect.
- In graph mode, it creates the finalize ops and returns them.
-
- For example, TPU shutdown ops.
-
- Returns:
- A list of ops to execute.
- """
- return self._extended._finalize() # pylint: disable=protected-access
-
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def run_steps_on_dataset(self, fn, iterator, iterations=1,
- initial_loop_values=None):
- """DEPRECATED: use extended.experimental_run_steps_on_iterator() instead."""
- return self._extended.experimental_run_steps_on_iterator(
- fn, iterator, iterations, initial_loop_values)
-
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def call_for_each_replica(self, fn, *args, **kwargs):
- """DEPRECATED: use extended.call_for_each_replica() instead."""
- # Handle old *args, **kwargs, and new args=(...), kwargs={...}, to
- # allow transition.
- a = kwargs.pop("args", None)
- if a is not None:
- if args:
- raise ValueError(
- "Can't pass *args and args=... to call_for_each_replica")
- args = a
- k = kwargs.pop("kwargs", None)
- if k is not None:
- if kwargs:
- raise ValueError(
- "Can't pass **kwargs and kwargs=... to call_for_each_replica")
- kwargs = k
- kwargs.pop("run_concurrently", None) # Ignore old option.
- return self._extended.call_for_each_replica(fn, args, kwargs)
-
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def reduce(self, aggregation, value, destinations):
- """DEPRECATED: use extended.reduce_to() instead."""
- return self._extended.reduce_to(aggregation, value, destinations)
-
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def batch_reduce(self, aggregation, value_destination_pairs):
- """DEPRECATED: use extended.batch_reduce_to() instead."""
- return self._extended.batch_reduce_to(aggregation, value_destination_pairs)
-
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def update(self, var, fn, *args, **kwargs):
- """DEPRECATED: use extended.update() instead."""
- group = kwargs.pop("group", True)
- # We temporarily support "grouped" in addition to "group" for backward-
- # compatibility.
- group = kwargs.pop("grouped", True) and group
- # Handle old *args, **kwargs, and new args=(...), kwargs={...}, to
- # allow transition.
- a = kwargs.pop("args", None)
- if a is not None:
- if args:
- raise ValueError(
- "Can't pass *args and args=... to update")
- args = a
- k = kwargs.pop("kwargs", None)
- if k is not None:
- if kwargs:
- raise ValueError(
- "Can't pass **kwargs and kwargs=... to update")
- kwargs = k
- return self._extended.update(var, fn, args, kwargs, group)
-
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def update_non_slot(self, colocate_with, fn, *args, **kwargs):
- """DEPRECATED: use extended.update_non_slot() instead."""
- group = kwargs.pop("group", True)
- # We temporarily support "grouped" in addition to "group" for backward-
- # compatibility.
- group = kwargs.pop("grouped", True) and group
- # Handle old *args, **kwargs, and new args=(...), kwargs={...}, to
- # allow transition.
- a = kwargs.pop("args", None)
- if a is not None:
- if args:
- raise ValueError(
- "Can't pass *args and args=... to update_non_slot")
- args = a
- k = kwargs.pop("kwargs", None)
- if k is not None:
- if kwargs:
- raise ValueError(
- "Can't pass **kwargs and kwargs=... to update_non_slot")
- kwargs = k
- return self._extended.update_non_slot(
- colocate_with, fn, args, kwargs, group)
-
- @doc_controls.do_not_generate_docs # DEPRECATED, -> `DistributedValues`
- def unwrap(self, value):
- """Returns the list of all per-replica values contained in `value`.
-
- Args:
- value: A value returned by `extended.call_for_each_replica()` or a
- variable created in `scope`.
-
- Returns:
- A list of values contained in `value`. If `value` represents a single
- value, this returns `[value].`
- """
- return self._extended._unwrap(value) # pylint: disable=protected-access
-
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def value_container(self, value):
- """DEPRECATED: use extended.value_container() instead."""
- return self._extended.value_container(value)
-
- @doc_controls.do_not_generate_docs # DEPRECATED, -> `DistributedValues`
- def group(self, value, name=None):
- """Shortcut for `tf.group(self.unwrap(value))`."""
- return self._extended._group(value, name) # pylint: disable=protected-access
-
- @property
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def require_static_shapes(self):
- """DEPRECATED: use extended.require_static_shapes instead."""
- return self._extended.experimental_require_static_shapes
-
- @property
- def num_replicas_in_sync(self):
- """Returns number of replicas over which gradients are aggregated."""
- return self._extended._num_replicas_in_sync # pylint: disable=protected-access
-
- @property
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def worker_devices(self):
- """DEPRECATED: use extended.worker_devices instead."""
- return self._extended.worker_devices
-
- @property
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def parameter_devices(self):
- """DEPRECATED: use extended.parameter_devices instead."""
- return self._extended.parameter_devices
-
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def non_slot_devices(self, var_list):
- """DEPRECATED: use extended.non_slot_devices instead."""
- return self._extended.non_slot_devices(var_list)
-
- @property
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def between_graph(self):
- """DEPRECATED: use extended.experimental_between_graph instead."""
- return self._extended.experimental_between_graph
-
- @doc_controls.do_not_generate_docs # DEPRECATED, being replaced by a new API.
- def configure(self,
- session_config=None,
- cluster_spec=None,
- task_type=None,
- task_id=None):
- """Configures the strategy class."""
- return self._extended._configure( # pylint: disable=protected-access
- session_config, cluster_spec, task_type, task_id)
-
- @property
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def should_init(self):
- """DEPRECATED: use extended.should_init instead."""
- return self._extended.experimental_should_init
-
- @property
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def should_checkpoint(self):
- """DEPRECATED: use extended.should_checkpoint instead."""
- return self._extended.should_checkpoint
-
- @property
- @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended`
- def should_save_summary(self):
- """DEPRECATED: use extended.should_save_summary instead."""
- return self._extended.should_save_summary
-
- def __deepcopy__(self, memo):
- # First do a regular deepcopy of `self`.
- cls = self.__class__
- result = cls.__new__(cls)
- memo[id(self)] = result
- for k, v in self.__dict__.items():
- setattr(result, k, copy.deepcopy(v, memo))
- # One little fix-up: we want `result._extended` to reference `result`
- # instead of `self`.
- result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access
- return result
-
- def __copy__(self):
- raise RuntimeError("Must only deepcopy DistributionStrategy.")
-
-
-@tf_export("distribute.StrategyExtended")
-class DistributionStrategyExtended(object):
- """Additional APIs for algorithms that need to be distribution-aware.
-
- The intent is that you can write an algorithm in a stylized way and
- it will be usable with a variety of different
- `tf.distribute.Strategy`
- implementations. Each descendant will implement a different strategy
- for distributing the algorithm across multiple devices/machines.
- Furthermore, these changes can be hidden inside the specific layers
- and other library classes that need special treatment to run in a
- distributed setting, so that most users' model definition code can
- run unchanged. The `tf.distribute.Strategy` API works the same way
- with eager and graph execution.
-
- First let's introduce a few high-level concepts:
-
- * _Data parallelism_ is where we run multiple copies of the model
- on different slices of the input data. This is in contrast to
- _model parallelism_ where we divide up a single copy of a model
- across multiple devices.
- Note: we only support data parallelism for now, but
- hope to add support for model parallelism in the future.
- * A _replica_ is one copy of the model, running on one slice of the
- input data.
- * _Synchronous_, or more commonly _sync_, training is where the
- updates from each replica are aggregated together before updating
- the model variables. This is in contrast to _asynchronous_, or
- _async_ training, where each replica updates the model variables
- independently.
- * Furthermore you might run your computation on multiple devices
- on one machine (or "host"), or on multiple machines/hosts.
- If you are running on multiple machines, you might have a
- single master host that drives computation across all of them,
- or you might have multiple clients driving the computation
- asynchronously.
-
- To distribute an algorithm, we might use some of these ingredients:
-
- * Parameter servers: These are hosts that hold a single copy of
- parameters/variables. All replicas that want to operate on a variable
- retrieve it at the beginning of a step and send an update to be
- applied at the end of the step. Can support either sync or async
- training.
- * Mirrored variables: These are variables that are copied to multiple
- devices, where we keep the copies in sync by applying the same
- updates to every copy. Normally would only be used with sync training.
- * Reductions and Allreduce: A _reduction_ is some method of
- aggregating multiple values into one value, like "sum" or
- "mean". If doing sync training, we will perform a reduction on the
- gradients to a parameter from all replicas before applying the
- update. Allreduce is an algorithm for performing a reduction on
- values from multiple devices and making the result available on
- all of those devices.
- * In the future we will have support for TensorFlow's partitioned
- variables, where a single variable is split across multiple
- devices.
-
- We have then a few approaches we want to support:
-
- * Code written (as if) with no knowledge of class `tf.distribute.Strategy`.
- This code should work as before, even if some of the layers, etc.
- used by that code are written to be distribution-aware. This is done
- by having a default `tf.distribute.Strategy` that gives ordinary behavior,
- and by default being in a single replica context.
- * Ordinary model code that you want to run using a specific
- `tf.distribute.Strategy`. This can be as simple as:
-
- ```
- with my_strategy.scope():
- iterator = my_strategy.make_dataset_iterator(dataset)
- session.run(iterator.initialize())
- replica_train_ops = my_strategy.extended.call_for_each_replica(
- replica_fn, args=(iterator.get_next(),))
- train_op = my_strategy.group(replica_train_ops)
- ```
-
- This takes an ordinary `dataset` and `replica_fn` and runs it
- distributed using a particular `tf.distribute.Strategy` in
- `my_strategy`. Any variables created in `replica_fn` are created
- using `my_strategy`'s policy, and library functions called by
- `replica_fn` can use the `get_replica_context()` API to get enhanced
- behavior in this case.
-
- * If you want to write a distributed algorithm, you may use any of
- the `tf.distribute.Strategy` APIs inside a
- `with my_strategy.scope():` block of code.
-
- Lower-level concepts:
-
- * Wrapped values: In order to represent values parallel across devices
- (either replicas or the devices associated with a particular value), we
- wrap them in a "PerReplica" or "Mirrored" object that contains a map
- from device to values. "PerReplica" is used when the value may be
- different across replicas, and "Mirrored" when the value are the same.
- * Unwrapping and merging: Consider calling a function `fn` on multiple
- replicas, like `extended.call_for_each_replica(fn, args=[w])` with an
- argument `w` that is a wrapped value. This means `w` will have a map taking
- replica device `d0` to `w0`, replica device `d1` to `w1`,
- etc. `extended.call_for_each_replica()` unwraps `w` before calling `fn`, so
- it calls `fn(w0)` on `d0`, `fn(w1)` on `d1`, etc. It then merges the return
- values from `fn()`, which can possibly result in wrapped values. For
- example, let's say `fn()` returns a tuple with three components: `(x, a,
- v0)` from replica 0, `(x, b, v1)` on replica 1, etc. If the first component
- is the same object `x` from every replica, then the first component of the
- merged result will also be `x`. If the second component is different (`a`,
- `b`, ...) from each replica, then the merged value will have a wrapped map
- from replica device to the different values. If the third component is the
- members of a mirrored variable (`v` maps `d0` to `v0`, `d1` to `v1`, etc.),
- then the merged result will be that mirrored variable (`v`).
- * Replica context vs. Cross-replica context: _replica context_ is when we
- are in some function that is being called once for each replica.
- Otherwise we are in cross-replica context, which is useful for
- calling `tf.distribute.Strategy` methods which operate across the
- replicas (like `reduce_to()`). By default you start in a replica context
- (the default "single replica context") and then some methods can
- switch you back and forth, as described below.
- * Worker devices vs. parameter devices: Most replica computations will
- happen on worker devices. Since we don't yet support model
- parallelism, there will be one worker device per replica. When using
- parameter servers (see above), the set of devices holding
- variables may be different, otherwise the parameter devices might
- match the worker devices.
- * Non-slot devices are some subset of the parameter devices where we
- put all the non-slot variables. We need to ensure that all
- non-slot variables are allocated on the same device, or mirrored
- across the same set of devices. If you have some variable you want
- to colocate all the non-slot variables with, you can use
- `colocate_vars_with()` to get the remaining non-slot variables on
- the same device. Otherwise you can use `non_slot_devices()` to
- pick a consistent set of devices to pass to both
- `colocate_vars_with()` and `update_non_slot()`.
-
- When using a `tf.distribute.Strategy`, we have a new type dimension
- called _locality_ that says what values are compatible with which
- APIs:
-
- * T: different value for each replica (e.g. a PerReplica-wrapped value).
- * M: value is "mirrored" across replicas, i.e. there are copies with the
- same value on each replica (e.g. a Mirrored-wrapped value).
- * V(`v`): value is "mirrored" across all the devices which have a
- copy of variable `v` (also a Mirrored-wrapped value, but over
- parameter devices instead of worker devices).
- * N: value is "mirrored" across all the "non-slot" devices
-
- Rules for methods with respect to locality and single-replica vs.
- cross-replica context:
-
- * `with d.scope()`: default single-replica context -> cross-replica context
- for `d`
- * `with d.extended.colocate_vars_with(v)`: in replica/cross-replica context,
- variables will be created with locality V(`v`). That is, if we write
- `with d.extended.colocate_vars_with(v1): v2 = tf.get_variable(...)`,
- then `v2` will have locality V(`v1`), i.e. locality V(`v2`) will equal
- V(`v1`).
- * `with d.extended.colocate_vars_with(d.extended.non_slot_devices(...))`: in
- replica/cross-replica context, variables will be created with locality N
- * `v = tf.get_variable(...)`: in replica/cross-replica context, creates
- a variable (which by definition will have locality V(`v`), though
- will match another locality if inside a `colocate_vars_with`
- scope).
- * `d.make_dataset_iterator(dataset)` (or the deprecated
- `d.distribute_dataset(dataset).make_one_shot_iterator()`): in cross-replica
- context, produces an iterator with locality T
- * `d.extended.broadcast_to(t)`: in cross-replica context, produces a value
- with locality M
- * `d.extended.broadcast_to(t, v)`: in cross-replica context, produces a value
- with locality V(`v`)
- * `d.extended.call_for_each_replica(fn, ...)`: in cross-replica context, runs
- `fn()` in a replica context (and so may call `get_replica_context()` and
- use its API, including `merge_call()` to get back to cross-replica
- context), once for each replica. May use values with locality T or
- M, and any variable.
- * `d.extended.reduce_to(m, t, t)`: in cross-replica context, accepts t with
- locality T and produces a value with locality M.
- * `d.extended.reduce_to(m, t, v)`: in cross-replica context, accepts t with
- locality T and produces a value with locality V(`v`).
- * `d.extended.batch_reduce_to(m, [(t, v)]): see `d.extended.reduce_to()`
- * `d.extended.update(v, fn, ...)`: in cross-replica context, runs `fn()` once
- for each device `v` is copied to, all inputs should have locality
- V(`v`), output will have locality V(`v`) as well.
- * `d.extended.update_non_slot(d.extended.non_slot_devices(), fn)`: in
- cross-replica context, like `d.extended.update()` except with locality N.
- * `d.extended.read_var(v)`: Gets the (read-only) value of the variable `v` (on
- the device determined by the current device scope), aggregating
- across replicas for replica-local variables. Frequently, this will be
- done automatically when using `v` in an expression or fetching it in
- a cross-replica context, but this function can be used to force that
- conversion happens at a particular point in time (for example, to
- add the result of the conversion to a graph collection).
-
- The standard pattern for updating variables is to:
-
- 1. Create an input iterator with `d.make_dataset_iterator()`.
- 2. Define each replica `d.extended.call_for_each_replica()` up to the point of
- getting a list of gradient, variable pairs.
- 3. Call `d.extended.reduce_to(VariableAggregation.SUM, t, v)` or
- `d.extended.batch_reduce_to()` to sum the gradients (with locality T)
- into values with locality V(`v`).
- 4. Call `d.extended.update(v)` for each variable to update its value.
-
- Steps 3 and 4 are done automatically by class `Optimizer` if you call
- its `apply_gradients` method in a replica context. Otherwise you can
- manually call its `_distributed_apply` method in a cross-replica context.
-
- Another thing you might want to do in the middle of your replica function is
- an all-reduce of some intermediate value, using `d.extended.reduce_to()` or
- `d.extended.batch_reduce_to()`. You simply provide the same tensor as the
- input and destination.
-
- Layers should expect to be called in a replica context, and can use
- the `tf.distribute.get_replica_context` function to get a
- `tf.distribute.ReplicaContext` object. The
- `ReplicaContext` object has a `merge_call()` method for entering
- cross-replica context where you can use `reduce_to()` (or
- `batch_reduce_to()`) and then optionally `update()` to update state.
-
- You may use this API whether or not a `tf.distribute.Strategy` is
- being used, since there is a default implementation of
- `ReplicaContext` and `tf.distribute.Strategy`.
-
- NOTE for new `tf.distribute.Strategy` implementations: Please put all logic
- in a subclass of `tf.distribute.StrategyExtended`. The only code needed for
- the `tf.distribute.Strategy` subclass is for instantiating your subclass of
- `tf.distribute.StrategyExtended` in the `__init__` method.
- """
-
- def __init__(self, container_strategy):
- self._container_strategy_weakref = weakref.ref(container_strategy)
- self._default_device = None
- # This property is used to determine if we should set drop_remainder=True
- # when creating Datasets from numpy array inputs.
- self._require_static_shapes = False
-
- def _container_strategy(self):
- """Get the containing `DistributionStrategy`.
-
- This should not generally be needed except when creating a new
- `ReplicaContext` and to validate that the caller is in the correct
- `scope()`.
-
- Returns:
- The `DistributionStrategy` such that `strategy.extended` is `self`.
- """
- container_strategy = self._container_strategy_weakref()
- assert container_strategy is not None
- return container_strategy
-
- def _scope(self, strategy):
- """Implementation of DistributionStrategy.scope()."""
- if distribution_strategy_context.has_distribution_strategy():
- _require_cross_replica_context_extended(self)
- return _SameScopeAgainContext(strategy)
-
- def creator_with_resource_vars(*args, **kwargs):
- _require_distribution_strategy_scope_extended(self)
- kwargs["use_resource"] = True
- return self._create_variable(*args, **kwargs)
-
- def distributed_getter(getter, *args, **kwargs):
- if not self._allow_variable_partition():
- if kwargs.pop("partitioner", None) is not None:
- tf_logging.log_first_n(
- tf_logging.WARN, "Partitioned variables are disabled when using "
- "current tf.distribute.Strategy.", 1)
- return getter(*args, **kwargs)
-
- return _CurrentDistributionContext(
- strategy,
- variable_scope.variable_creator_scope(creator_with_resource_vars),
- variable_scope.variable_scope(
- variable_scope.get_variable_scope(),
- custom_getter=distributed_getter), self._default_device)
-
- def _allow_variable_partition(self):
- return False
-
- def _create_variable(self, next_creator, *args, **kwargs):
- # Note: should support "colocate_with" argument.
- raise NotImplementedError("must be implemented in descendants")
-
- def read_var(self, v):
- """Reads the value of a variable.
-
- Returns the aggregate value of a replica-local variable, or the
- (read-only) value of any other variable.
-
- Args:
- v: A variable allocated within the scope of this `tf.distribute.Strategy`.
-
- Returns:
- A tensor representing the value of `v`, aggregated across replicas if
- necessary.
- """
- raise NotImplementedError("must be implemented in descendants")
-
- def colocate_vars_with(self, colocate_with_variable):
- """Scope that controls which devices variables will be created on.
-
- No operations should be added to the graph inside this scope, it
- should only be used when creating variables (some implementations
- work by changing variable creation, others work by using a
- tf.colocate_with() scope).
-
- This may only be used inside `self.scope()`.
-
- Example usage:
-
- ```
- with strategy.scope():
- var1 = tf.get_variable(...)
- with strategy.extended.colocate_vars_with(v1):
- # var2 and var3 will be created on the same device(s) as var1
- var2 = tf.get_variable(...)
- var3 = tf.get_variable(...)
-
- def fn(v1, v2, v3):
- # operates on v1 from var1, v2 from var2, and v3 from var3
-
- # `fn` runs on every device `v1` is on, `v2` and `v3` will be there too.
- strategy.extended.update(v1, fn, args=(v2, v3))
- ```
-
- Args:
- colocate_with_variable: A created in `self.scope()`. Variables created
- while in the returned context manager will be on the same set of
- devices as `colocate_with_variable`.
-
- Returns:
- A context manager.
- """
- def create_colocated_variable(next_creator, *args, **kwargs):
- _require_distribution_strategy_scope_extended(self)
- kwargs["use_resource"] = True
- kwargs["colocate_with"] = colocate_with_variable
- return next_creator(*args, **kwargs)
-
- _require_distribution_strategy_scope_extended(self)
- return variable_scope.variable_creator_scope(create_colocated_variable)
-
- def _call_dataset_fn(self, dataset_fn):
- """Call the `dataset_fn` with `input_context` as argument."""
- result = dataset_fn()
- if not isinstance(result, dataset_ops.Dataset):
- raise ValueError(
- "dataset_fn() must return a tf.data.Dataset when using a "
- "tf.distribute.Strategy.")
- return result
-
- # TODO(josh11b): `PerReplicaDataset` currently only implements a few methods of
- # Dataset API such as make_one_shot_iterator and make_initializable_iterator.
- # Extend to implement more functionality of datasets.
- def _distribute_dataset(self, dataset_fn):
- raise NotImplementedError("must be implemented in descendants")
-
- def _make_dataset_iterator(self, dataset):
- raise NotImplementedError("must be implemented in descendants")
-
- def _make_input_fn_iterator(self, input_fn, replication_mode):
- raise NotImplementedError("must be implemented in descendants")
-
- def broadcast_to(self, tensor, destinations):
- """Mirror a tensor on one device to all worker devices.
-
- Args:
- tensor: A Tensor value to broadcast.
- destinations: A mirrored variable, device string, or list of device
- strings, specifying the destination devices to copy `tensor` to.
-
- Returns:
- A value mirrored to `destinations` devices.
- """
- # TODO(josh11b): More docstring
- _require_cross_replica_context_extended(self)
- return self._broadcast_to(tensor, destinations)
-
- def _broadcast_to(self, tensor, destinations):
- raise NotImplementedError("must be implemented in descendants")
-
- def _initialize(self):
- return []
-
- def _finalize(self):
- return []
-
- def experimental_run_steps_on_iterator(self, fn, iterator, iterations=1,
- initial_loop_values=None):
- """Run `fn` with input from `iterator` for `iterations` times.
-
- This method can be used to run a step function for training a number of
- times using input from a dataset.
-
- Args:
- fn: function to run using this distribution strategy. The function must
- have the following signature: `def fn(context, inputs)`.
- `context` is an instance of `MultiStepContext` that will be passed when
- `fn` is run. `context` can be used to specify the outputs to be returned
- from `fn` by calling `context.set_last_step_output`. It can also be used
- to capture non tensor outputs by `context.set_non_tensor_output`.
- See `MultiStepContext` documentation for more information.
- `inputs` will have same type/structure as `iterator.get_next()`.
- Typically, `fn` will use `call_for_each_replica` method of the strategy
- to distribute the computation over multiple replicas.
- iterator: Iterator of a dataset that represents the input for `fn`. The
- caller is responsible for initializing the iterator as needed.
- iterations: (Optional) Number of iterations that `fn` should be run.
- Defaults to 1.
- initial_loop_values: (Optional) Initial values to be passed into the
- loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove
- initial_loop_values argument when we have a mechanism to infer the
- outputs of `fn`.
-
- Returns:
- Returns the `MultiStepContext` object which has the following properties,
- among other things:
- - run_op: An op that runs `fn` `iterations` times.
- - last_step_outputs: A dictionary containing tensors set using
- `context.set_last_step_output`. Evaluating this returns the value of
- the tensors after the last iteration.
- - non_tensor_outputs: A dictionatry containing anything that was set by
- `fn` by calling `context.set_non_tensor_output`.
- """
- _require_cross_replica_context_extended(self)
- return self._experimental_run_steps_on_iterator(
- fn, iterator, iterations, initial_loop_values)
-
- def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
- initial_loop_values):
- raise NotImplementedError("must be implemented in descendants")
-
- def call_for_each_replica(self, fn, args=(), kwargs=None):
- """Run `fn` once per replica.
-
- `fn` may call `tf.get_replica_context()` to access methods such as
- `replica_id_in_sync_group` and `merge_call()`.
-
- `merge_call()` is used to communicate between the replicas and
- re-enter the cross-replica context. All replicas pause their execution
- having encountered a `merge_call()` call. After that the
- `merge_fn`-function is executed. Its results are then unwrapped and
- given back to each replica call. After that execution resumes until
- `fn` is complete or encounters another `merge_call()`. Example:
-
- ```python
- # Called once in "cross-replica" context.
- def merge_fn(distribution, three_plus_replica_id):
- # sum the values across replicas
- return sum(distribution.unwrap(three_plus_replica_id))
-
- # Called once per replica in `distribution`, in a "replica" context.
- def fn(three):
- replica_ctx = tf.get_replica_context()
- v = three + replica_ctx.replica_id_in_sync_group
- # Computes the sum of the `v` values across all replicas.
- s = replica_ctx.merge_call(merge_fn, args=(v,))
- return s + v
-
- with distribution.scope():
- # in "cross-replica" context
- ...
- merged_results = distribution.call_for_each_replica(fn, args=[3])
- # merged_results has the values from every replica execution of `fn`.
- print(distribution.unwrap(merged_results)) # Prints a list
- ```
-
- Args:
- fn: function to run (will be run once per replica).
- args: Tuple or list with positional arguments for `fn`.
- kwargs: Dict with keyword arguments for `fn`.
-
- Returns:
- Merged return value of `fn` across all replicas.
- """
- _require_cross_replica_context_extended(self)
- if kwargs is None:
- kwargs = {}
- return self._call_for_each_replica(fn, args, kwargs)
-
- def _call_for_each_replica(self, fn, args, kwargs):
- raise NotImplementedError("must be implemented in descendants")
-
- def reduce_to(self, reduce_op, value, destinations):
- """Combine (via e.g. sum or mean) values across replicas.
-
- Args:
- reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
- DEPRECATED but still accepted values:
- `tf.VariableAggregation.SUM`,
- `tf.VariableAggregation.MEAN`,
- value: A per-replica value with one value per replica.
- destinations: A mirrored variable, a per-replica tensor, a device string,
- or list of device strings. The return value will be copied to all
- destination devices (or all the devices where the `destinations` value
- resides). To perform an all-reduction, pass `value` to `destinations`.
-
- Returns:
- A value mirrored to `destinations`.
- """
- # TODO(josh11b): More docstring
- # TODO(josh11b): Return an unwrapped value if colocate_with is a
- # single device.
- _require_cross_replica_context_extended(self)
-
- # TODO(priyag): Remove this when all callers have been updated.
- if isinstance(reduce_op, variable_scope.VariableAggregation):
- assert reduce_op in [
- variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN,
- ]
- reduce_op = reduce_util.ReduceOp.from_variable_aggregation(reduce_op)
- return self._reduce_to(reduce_op, value, destinations)
-
- def _reduce_to(self, reduce_op, value, destinations):
- raise NotImplementedError("must be implemented in descendants")
-
- def batch_reduce_to(self, reduce_op, value_destination_pairs):
- """Combine multiple `reduce_to` calls into one for faster execution.
-
- Args:
- reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
- DEPRECATED but still accepted values:
- `tf.VariableAggregation.SUM`,
- `tf.VariableAggregation.MEAN`,
- value_destination_pairs: A sequence of (value, destinations)
- pairs. See `reduce_to()` for a description.
-
- Returns:
- A list of mirrored values, one per pair in `value_destination_pairs`.
- """
- # TODO(josh11b): More docstring
- _require_cross_replica_context_extended(self)
-
- # TODO(priyag): Remove this when all callers have been updated.
- if isinstance(reduce_op, variable_scope.VariableAggregation):
- assert reduce_op in [
- variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN,
- ]
- reduce_op = reduce_util.ReduceOp.from_variable_aggregation(reduce_op)
- return self._batch_reduce_to(reduce_op, value_destination_pairs)
-
- def _batch_reduce_to(self, reduce_op, value_destination_pairs):
- return [
- self.reduce_to(reduce_op, t, destinations=v)
- for t, v in value_destination_pairs
- ]
-
- def update(self, var, fn, args=(), kwargs=None, group=True):
- """Run `fn` to update `var` using inputs mirrored to the same devices.
-
- If `var` is mirrored across multiple devices, then this implements
- logic like:
-
- ```
- results = {}
- for device, v in var:
- with tf.device(device):
- # args and kwargs will be unwrapped if they are mirrored.
- results[device] = fn(v, *args, **kwargs)
- return merged(results)
- ```
-
- Otherwise this returns `fn(var, *args, **kwargs)` colocated with `var`.
-
- Neither `args` nor `kwargs` may contain per-replica values.
- If they contain mirrored values, they will be unwrapped before
- calling `fn`.
-
- Args:
- var: Variable, possibly mirrored to multiple devices, to operate on.
- fn: Function to call. Should take the variable as the first argument.
- args: Tuple or list. Additional positional arguments to pass to `fn()`.
- kwargs: Dict with keyword arguments to pass to `fn()`.
- group: Boolean. Defaults to True. If False, the return value will be
- unwrapped.
-
- Returns:
- By default, the merged return value of `fn` across all replicas. The
- merged result has dependencies to make sure that if it is evaluated at
- all, the side effects (updates) will happen on every replica. If instead
- "group=False" is specified, this function will return a nest of lists
- where each list has an element per replica, and the caller is responsible
- for ensuring all elements are executed.
- """
- _require_cross_replica_context_extended(self)
- if kwargs is None:
- kwargs = {}
- return self._update(var, fn, args, kwargs, group)
-
- def _update(self, var, fn, args, kwargs, group):
- raise NotImplementedError("must be implemented in descendants")
-
- def update_non_slot(
- self, colocate_with, fn, args=(), kwargs=None, group=True):
- """Runs `fn(*args, **kwargs)` on `colocate_with` devices.
-
- Args:
- colocate_with: The return value of `non_slot_devices()`.
- fn: Function to execute.
- args: Tuple or list. Positional arguments to pass to `fn()`.
- kwargs: Dict with keyword arguments to pass to `fn()`.
- group: Boolean. Defaults to True. If False, the return value will be
- unwrapped.
-
- Returns:
- Return value of `fn`, possibly merged across devices.
- """
- _require_cross_replica_context_extended(self)
- if kwargs is None:
- kwargs = {}
- return self._update_non_slot(colocate_with, fn, args, kwargs, group)
-
- def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
- raise NotImplementedError("must be implemented in descendants")
-
- def _unwrap(self, distributed_value):
- raise NotImplementedError("must be implemented in descendants")
-
- def value_container(self, value):
- """Returns the container that this per-replica `value` belongs to.
-
- Args:
- value: A value returned by `call_for_each_replica()` or a variable
- created in `scope()`.
-
- Returns:
- A container that `value` belongs to.
- If value does not belong to any container (including the case of
- container having been destroyed), returns the value itself.
- `value in unwrap(value_container(value))` will always be true.
- """
- raise NotImplementedError("must be implemented in descendants")
-
- def _group(self, value, name=None):
- """Shortcut for `tf.group(distribution.unwrap(value))`."""
- value = nest.flatten(self._unwrap(value))
-
- if len(value) != 1 or name is not None:
- return control_flow_ops.group(value, name=name)
- # Special handling for the common case of one op.
- v, = value
- if hasattr(v, "op"):
- v = v.op
- return v
-
- @property
- def experimental_require_static_shapes(self):
- return self._require_static_shapes
-
- @property
- def _num_replicas_in_sync(self):
- """Returns number of replicas over which gradients are aggregated."""
- raise NotImplementedError("must be implemented in descendants")
-
- @property
- def worker_devices(self):
- """Returns the list of devices used to run `call_for_each_replica()` calls.
- """
- # TODO(josh11b): More docstring
- raise NotImplementedError("must be implemented in descendants")
-
- @property
- def parameter_devices(self):
- """Returns the list of devices used for variable and `update` placement."""
- # TODO(josh11b): More docstring
- raise NotImplementedError("must be implemented in descendants")
-
- def non_slot_devices(self, var_list):
- """Device(s) for non-slot variables.
-
- Create variables on these devices in a
- `with colocate_vars_with(non_slot_devices(...)):` block.
- Update those using `update_non_slot()`.
-
- Args:
- var_list: The list of variables being optimized, needed with the
- default `tf.distribute.Strategy`.
- """
- raise NotImplementedError("must be implemented in descendants")
-
- @property
- def experimental_between_graph(self):
- """Whether the strategy uses between-graph replication or not.
-
- This is expected to return a constant value that will not be changed
- throughout its life cycle.
- """
- raise NotImplementedError("must be implemented in descendants")
-
- def _configure(self,
- session_config=None,
- cluster_spec=None,
- task_type=None,
- task_id=None):
- """Configures the strategy class."""
- del session_config, cluster_spec, task_type, task_id
-
- @property
- def experimental_should_init(self):
- """Whether initialization is needed."""
- raise NotImplementedError("must be implemented in descendants")
-
- @property
- def should_checkpoint(self):
- """Whether checkpointing is needed."""
- raise NotImplementedError("must be implemented in descendants")
-
- @property
- def should_save_summary(self):
- """Whether saving summaries is needed."""
- raise NotImplementedError("must be implemented in descendants")
-
-
-# A note about the difference between the context managers
-# `ReplicaContext` (defined here) and `_CurrentDistributionContext`
-# (defined above) used by `DistributionStrategy.scope()`:
-#
-# * a ReplicaContext is only present during a `call_for_each_replica()`
-# call (except during a `merge_run` call) and in such a scope it
-# will be returned by calls to `get_replica_context()`. Implementers of new
-# DistributionStrategy descendants will frequently also need to
-# define a descendant of ReplicaContext, and are responsible for
-# entering and exiting this context.
-#
-# * DistributionStrategy.scope() sets up a variable_creator scope that
-# changes variable creation calls (e.g. to make mirrored
-# variables). This is intended as an outer scope that users enter once
-# around their model creation and graph definition. There is no
-# anticipated need to define descendants of _CurrentDistributionContext.
-# It sets the current DistributionStrategy for purposes of
-# `get_strategy()` and `has_strategy()`
-# and switches the thread mode to a "cross-replica context".
-@tf_export("distribute.ReplicaContext")
-class ReplicaContext(object):
- """`tf.distribute.Strategy` API when in a replica context.
-
- To be used inside your replicated step function, such as in a
- `tf.distribute.StrategyExtended.call_for_each_replica` call.
- """
-
- def __init__(self, strategy, replica_id_in_sync_group):
- self._distribution_strategy = strategy
- self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access
- self)
- self._replica_id_in_sync_group = replica_id_in_sync_group
-
- def __enter__(self):
- _push_per_thread_mode(self._thread_context)
-
- def __exit__(self, exception_type, exception_value, traceback):
- _pop_per_thread_mode()
-
- def merge_call(self, merge_fn, args=(), kwargs=None):
- """Merge args across replicas and run `merge_fn` in a cross-replica context.
-
- This allows communication and coordination when there are multiple calls
- to a model function triggered by a call to
- `strategy.extended.call_for_each_replica(model_fn, ...)`.
-
- See `tf.distribute.StrategyExtended.call_for_each_replica` for an
- explanation.
-
- If not inside a distributed scope, this is equivalent to:
-
- ```
- strategy = tf.distribute.get_strategy()
- with cross-replica-context(strategy):
- return merge_fn(strategy, *args, **kwargs)
- ```
-
- Args:
- merge_fn: function that joins arguments from threads that are given as
- PerReplica. It accepts `tf.distribute.Strategy` object as
- the first argument.
- args: List or tuple with positional per-thread arguments for `merge_fn`.
- kwargs: Dict with keyword per-thread arguments for `merge_fn`.
-
- Returns:
- The return value of `merge_fn`, except for `PerReplica` values which are
- unpacked.
- """
- require_replica_context(self)
- if kwargs is None:
- kwargs = {}
- return self._merge_call(merge_fn, args, kwargs)
-
- def _merge_call(self, merge_fn, args, kwargs):
- """Default implementation for single replica."""
- _push_per_thread_mode( # thread-local, so not needed with multiple threads
- distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access
- self._distribution_strategy))
- try:
- return merge_fn(self._distribution_strategy, *args, **kwargs)
- finally:
- _pop_per_thread_mode()
-
- @property
- def num_replicas_in_sync(self):
- """Returns number of replicas over which gradients are aggregated."""
- return self._distribution_strategy.num_replicas_in_sync
-
- @property
- def replica_id_in_sync_group(self):
- """Which replica is being defined, from 0 to `num_replicas_in_sync - 1`."""
- require_replica_context(self)
- return self._replica_id_in_sync_group
-
- @property
- @doc_controls.do_not_generate_docs # DEPRECATED, use `strategy`
- def distribution_strategy(self):
- """DEPRECATED: use `self.stratgey` instead."""
- return self._distribution_strategy
-
- @property
- def strategy(self):
- """The current `tf.distribute.Strategy` object."""
- return self._distribution_strategy
-
- @property
- def devices(self):
- """The devices this replica is to be executed on, as a list of strings."""
- require_replica_context(self)
- return [device_util.current()]
-
- # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient
- # all-reduce. It would return a function returning the result of reducing `t`
- # across all replicas. The caller would wait to call this function until they
- # needed the reduce result, allowing an efficient implementation:
- # * With eager execution, the reduction could be performed asynchronously
- # in the background, not blocking until the result was needed.
- # * When constructing a graph, it could batch up all reduction requests up
- # to that point that the first result is needed. Most likely this can be
- # implemented in terms of `merge_call()` and `batch_reduce_to()`.
-
-# ------------------------------------------------------------------------------
-
-
-class _DefaultDistributionStrategy(DistributionStrategy):
- """Default `tf.distribute.Strategy` if none is explicitly selected."""
-
- def __init__(self):
- super(_DefaultDistributionStrategy, self).__init__(
- _DefaultDistributionExtended(self))
-
-
-class _DefaultDistributionExtended(DistributionStrategyExtended):
- """Implementation of _DefaultDistributionStrategy."""
-
- def _scope(self, strategy):
- """Context manager setting a variable creator and `self` as current."""
- if distribution_strategy_context.has_distribution_strategy():
- raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")
-
- def creator(next_creator, *args, **kwargs):
- _require_distribution_strategy_scope_strategy(strategy)
- return next_creator(*args, **kwargs)
-
- return _CurrentDistributionContext(
- strategy, variable_scope.variable_creator_scope(creator))
-
- def colocate_vars_with(self, colocate_with_variable):
- """Does not require `self.scope`."""
- _require_distribution_strategy_scope_extended(self)
- return ops.colocate_with(colocate_with_variable)
-
- def _distribute_dataset(self, dataset_fn):
- return self._call_dataset_fn(dataset_fn)
-
- def _make_dataset_iterator(self, dataset):
- return _DefaultDistributionExtended.DefaultInputIterator(dataset)
-
- def _make_input_fn_iterator(self,
- input_fn,
- replication_mode=InputReplicationMode.PER_WORKER):
- return input_fn(InputContext()).make_initializable_iterator()
-
- def _broadcast_to(self, tensor, destinations):
- if destinations is None:
- return tensor
- else:
- raise NotImplementedError("TODO")
-
- def _call_for_each_replica(self, fn, args, kwargs):
- with ReplicaContext(
- self._container_strategy(),
- replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
- return fn(*args, **kwargs)
-
- def _reduce_to(self, reduce_op, value, destinations):
- # TODO(josh11b): Use destinations?
- del reduce_op, destinations
- return value
-
- def _update(self, var, fn, args, kwargs, group):
- # The implementations of _update() and _update_non_slot() are identical
- # except _update() passes `var` as the first argument to `fn()`.
- return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)
-
- def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group):
- # TODO(josh11b): Figure out what we should be passing to UpdateContext()
- # once that value is used for something.
- with ops.colocate_with(colocate_with), UpdateContext(colocate_with):
- result = fn(*args, **kwargs)
- if should_group:
- return result
- else:
- return nest.map_structure(self._unwrap, result)
-
- def read_var(self, replica_local_var):
- return array_ops.identity(replica_local_var)
-
- def _unwrap(self, distributed_value):
- return [distributed_value]
-
- def value_container(self, value):
- return value
-
- @property
- def _num_replicas_in_sync(self):
- return 1
-
- @property
- def worker_devices(self):
- raise RuntimeError("worker_devices() method unsupported by default "
- "tf.distribute.Strategy.")
-
- @property
- def parameter_devices(self):
- raise RuntimeError("parameter_devices() method unsupported by default "
- "tf.distribute.Strategy.")
-
- def non_slot_devices(self, var_list):
- return min(var_list, key=lambda x: x.name)
-
- # TODO(priyag): This should inherit from `InputIterator`, once dependency
- # issues have been resolved.
- class DefaultInputIterator(object):
- """Default implementation of `InputIterator` for default strategy."""
-
- def __init__(self, dataset):
- self._dataset = dataset
- if eager_context.executing_eagerly():
- self._iterator = dataset.make_one_shot_iterator()
- else:
- self._iterator = dataset.make_initializable_iterator()
-
- def get_next(self):
- return self._iterator.get_next()
-
- def initialize(self):
- if eager_context.executing_eagerly():
- self._iterator = self._dataset.make_one_shot_iterator()
- return []
- else:
- return [self._iterator.initializer]
-
-
-# ------------------------------------------------------------------------------
-# We haven't yet implemented deserialization for DistributedVariables.
-# So here we catch any attempts to deserialize variables
-# when using distribution strategies.
-# pylint: disable=protected-access
-_original_from_proto = resource_variable_ops._from_proto_fn
-
-
-def _from_proto_fn(v, import_scope=None):
- if distribution_strategy_context.has_distribution_strategy():
- raise NotImplementedError(
- "Deserialization of variables is not yet supported when using a "
- "tf.distribute.Strategy.")
- else:
- return _original_from_proto(v, import_scope=import_scope)
-
-resource_variable_ops._from_proto_fn = _from_proto_fn
-# pylint: enable=protected-access
-
-
-#-------------------------------------------------------------------------------
-# Shorthand for some methods from distribution_strategy_context.
-_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access
-_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access
-_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access
+# pylint: disable=wildcard-import
+from tensorflow.python.distribute.distribute_lib import *
diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py
index 0b3878d..7391bf3 100644
--- a/tensorflow/python/training/distribution_strategy_context.py
+++ b/tensorflow/python/training/distribution_strategy_context.py
@@ -12,225 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Utility to get distribution strategy related contexts."""
+"""Deprecated, please use ../distribute/distribution_strategy_context.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import ops
-from tensorflow.python.util.lazy_loader import LazyLoader
-from tensorflow.python.util.tf_export import tf_export
-
-
-# There is a circular dependency between this and `distribute` module. So we
-# load it lazily to workaround this.
-distribute_lib = LazyLoader(
- "distribute_lib", globals(),
- "tensorflow.python.training.distribute")
-
-# ------------------------------------------------------------------------------
-# Internal API for setting the current thread mode as being either in a
-# replica or cross-replica context for a particular distribution strategy.
-
-
-class _ThreadMode(object):
-
- def __init__(self, dist, cross, replica):
- self.distribution_strategy = dist
- self.cross_replica_context = cross
- self.replica_context = replica
-
-
-class _CrossReplicaThreadMode(_ThreadMode):
-
- def __init__(self, distribution_strategy):
- _ThreadMode.__init__(
- self, distribution_strategy, distribution_strategy, None)
-
-
-class _InReplicaThreadMode(_ThreadMode):
-
- def __init__(self, replica_ctx):
- _ThreadMode.__init__(
- self, replica_ctx.distribution_strategy, None, replica_ctx)
-
-
-def _push_per_thread_mode(context):
- ops.get_default_graph()._distribution_strategy_stack.append(context) # pylint: disable=protected-access
-
-
-def _pop_per_thread_mode():
- ops.get_default_graph()._distribution_strategy_stack.pop(-1) # pylint: disable=protected-access
-
-
-class _DefaultReplicaThreadMode(_ThreadMode):
- """Type of default value returned by `_get_per_thread_mode()`.
-
- Used when the thread-local stack is empty.
- """
-
- def __init__(self):
- _ThreadMode.__init__(self, _get_default_distribution_strategy(), None,
- _get_default_replica_context())
-
-
-def _get_per_thread_mode():
- try:
- return ops.get_default_graph()._distribution_strategy_stack[-1] # pylint: disable=protected-access
- except (AttributeError, IndexError):
- return _get_default_replica_mode()
-
-
-# ------------------------------------------------------------------------------
-# Public API for accessing the current thread mode
-
-
-@tf_export("distribute.get_replica_context")
-def get_replica_context():
- """Returns the current `tf.distribute.ReplicaContext` or `None`.
-
- Returns `None` if in a cross-replica context.
-
- Note that execution:
-
- 1. starts in the default (single-replica) replica context (this function
- will return the default `ReplicaContext` object);
- 2. switches to cross-replica context (in which case this will return
- `None`) when entering a `with tf.distribute.Strategy.scope():` block;
- 3. switches to a (non-default) replica context inside
- `extended.call_for_each_replica(fn, ...)`;
- 4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
- inside `merge_fn` you are back in the cross-replica context (and again
- this function will return `None`).
-
- Note that you can also go directly from step 1 to 4 to switch to a
- cross-replica context for the default `tf.distribute.Strategy`. You may
- also switch from the cross-replica context of 4 to a replica context by
- calling `extended.call_for_each_replica()`, jumping back to step 3.
-
- Most `tf.distribute.Strategy` methods may only be executed in
- a cross-replica context, in a replica context you should use the
- `ReplicaContext` API instead.
-
- Returns:
- The current `ReplicaContext` object when in a replica context scope,
- else `None`.
-
- Within a particular block, exactly one of these two things will be true:
-
- * `get_replica_context()` returns non-`None`, or
- * `tf.distribute.is_cross_replica_context()` returns True.
- """
- return _get_per_thread_mode().replica_context
-
-
-def get_cross_replica_context():
- """Returns the current tf.distribute.Strategy if in a cross-replica context.
-
- DEPRECATED: Please use `in_cross_replica_context()` and
- `get_distribution_strategy()` instead.
-
- Note that execution:
-
- 1. starts in the default (single-replica) replica context;
- 2. switches to cross-replica context when entering a
- `with tf.distribute.Strategy.scope():` block;
- 3. switches to a (non-default) replica context inside
- `call_for_each_replica(fn, ...)`;
- 4. if `fn` calls `get_replica_context()->merge_call(merge_fn, ...)`, then
- inside `merge_fn` you are back in the cross-replica context.
-
- Note that you can also go directly from step 1 to 4 to switch to a
- cross-replica context for the default `tf.distribute.Strategy`. You may
- also switch from the cross-replica context of 4 to a replica context by
- calling `call_for_each_replica()`, jumping back to step 3.
-
- Most `tf.distribute.Strategy` methods may only be executed in
- a cross-replica context.
-
- Returns:
- Returns the current `tf.distribute.Strategy` object in a cross-replica
- context, or `None`.
-
- Exactly one of `get_replica_context()` and `get_cross_replica_context()`
- will return `None` in a particular block.
- """
- return _get_per_thread_mode().cross_replica_context
-
-
-@tf_export("distribute.in_cross_replica_context")
-def in_cross_replica_context():
- """Returns True if in a cross-replica context.
-
- See `tf.distribute.get_replica_context` for details.
-
- Returns:
- True if in a cross-replica context (`get_replica_context()` returns
- `None`), or False if in a replica context (`get_replica_context()` returns
- non-`None`).
- """
- return _get_per_thread_mode().cross_replica_context is not None
-
-
-@tf_export("distribute.get_strategy")
-def get_distribution_strategy():
- """Returns the current `tf.distribute.Strategy` object.
-
- Typically only used in a cross-replica context:
-
- ```
- if tf.distribute.in_cross_replica_context():
- strategy = tf.distribute.get_strategy()
- ...
- ```
-
- Returns:
- A `tf.distribute.Strategy` object. Inside a
- `with distribution_strategy.scope()` block, it returns
- `distribution_strategy`, otherwise it returns the default
- (single-replica) `tf.distribute.Strategy` object.
- """
- return _get_per_thread_mode().distribution_strategy
-
-
-@tf_export("distribute.has_strategy")
-def has_distribution_strategy():
- """Return if there is a current non-default `tf.distribute.Strategy`.
-
- Returns:
- True if inside a `with strategy.scope():`.
- """
- return get_distribution_strategy() is not _get_default_distribution_strategy()
-
-
-# ------------------------------------------------------------------------------
-# Defaults that are used when no distribution strategy is explicitly created.
-# We create them lazily in a function so that we can workaround the circular
-# dependency on distribute_lib. See lazy loader at the top of this file.
-
-_defaults = {
- "distribution_strategy": None,
- "replica_context": None,
- "replica_mode": None
-}
-
-
-def _get_default_distribution_strategy():
- if _defaults["distribution_strategy"] is None:
- _defaults["distribution_strategy"] = (
- distribute_lib._DefaultDistributionStrategy()) # pylint: disable=protected-access
- return _defaults["distribution_strategy"]
-
-
-def _get_default_replica_context():
- if _defaults["replica_context"] is None:
- _defaults["replica_context"] = distribute_lib.ReplicaContext(
- _get_default_distribution_strategy(), replica_id_in_sync_group=0)
- return _defaults["replica_context"]
-
-
-def _get_default_replica_mode():
- if _defaults["replica_mode"] is None:
- _defaults["replica_mode"] = _DefaultReplicaThreadMode()
- return _defaults["replica_mode"]
+# pylint: disable=wildcard-import
+from tensorflow.python.distribute.distribution_strategy_context import *
diff --git a/tensorflow/python/training/evaluation.py b/tensorflow/python/training/evaluation.py
index 2c4eb02..a10178f 100644
--- a/tensorflow/python/training/evaluation.py
+++ b/tensorflow/python/training/evaluation.py
@@ -230,7 +230,7 @@
hooks = list(hooks or [])
if eval_ops is not None:
- if any([isinstance(h, _MultiStepStopAfterNEvalsHook) for h in hooks]):
+ if any(isinstance(h, _MultiStepStopAfterNEvalsHook) for h in hooks):
steps_per_run_variable = \
basic_session_run_hooks.get_or_create_steps_per_run_variable()
update_eval_step = state_ops.assign_add(
diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py
index 2fafc9a..a2ef3c7 100644
--- a/tensorflow/python/training/ftrl.py
+++ b/tensorflow/python/training/ftrl.py
@@ -25,7 +25,7 @@
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.FtrlOptimizer")
+@tf_export(v1=["train.FtrlOptimizer"])
class FtrlOptimizer(optimizer.Optimizer):
"""Optimizer that implements the FTRL algorithm.
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
index a61132a..70b5db3 100644
--- a/tensorflow/python/training/ftrl_test.py
+++ b/tensorflow/python/training/ftrl_test.py
@@ -54,7 +54,7 @@
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose([0.0, 0.0], v0_val)
self.assertAllClose([0.0, 0.0], v1_val)
@@ -62,7 +62,7 @@
for _ in range(3):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType(
np.array([-2.60260963, -4.29698515]), v0_val)
self.assertAllCloseAccordingToType(
@@ -90,14 +90,14 @@
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
# Run 3 steps FTRL
for _ in range(3):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType(
np.array([-2.55607247, -3.98729396]), v0_val)
self.assertAllCloseAccordingToType(
@@ -137,14 +137,14 @@
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
# Run 10 steps FTRL
for _ in range(10):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType(
np.array([-7.66718769, -10.91273689]), v0_val)
self.assertAllCloseAccordingToType(
@@ -166,7 +166,7 @@
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
@@ -174,7 +174,7 @@
for _ in range(10):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType(
np.array([-0.24059935, -0.46829352]), v0_val)
self.assertAllCloseAccordingToType(
@@ -203,7 +203,7 @@
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
self.assertAllCloseAccordingToType([4.0, 3.0], v1_val)
@@ -211,7 +211,7 @@
for _ in range(10):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType(
np.array([-0.22578995, -0.44345796]), v0_val)
self.assertAllCloseAccordingToType(
@@ -239,7 +239,7 @@
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType([[1.0], [2.0]], v0_val)
self.assertAllCloseAccordingToType([[4.0], [3.0]], v1_val)
@@ -247,7 +247,7 @@
for _ in range(10):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType([[-0.22578995], [2.]], v0_val)
self.assertAllCloseAccordingToType([[4.], [-0.13229476]], v1_val)
@@ -275,7 +275,7 @@
update1 = opt1.apply_gradients([(grads1, var1)])
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllCloseAccordingToType([1.0, 2.0], v0_val)
self.assertAllCloseAccordingToType([1.0, 2.0], v1_val)
@@ -284,12 +284,12 @@
update0.run()
update1.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
# var0 is experiencing L2 shrinkage so it should be smaller than var1
# in magnitude.
self.assertTrue((v0_val**2 < v1_val**2).all())
- accum0 = list(sess.run(opt0._slots)["accum"].values())[0]
- accum1 = list(sess.run(opt1._slots)["accum"].values())[0]
+ accum0 = list(self.evaluate(opt0._slots)["accum"].values())[0]
+ accum1 = list(self.evaluate(opt1._slots)["accum"].values())[0]
# L2 shrinkage should not change how we update grad accumulator.
self.assertAllCloseAccordingToType(accum0, accum1)
@@ -313,7 +313,7 @@
variables.global_variables_initializer().run()
sess = ops.get_default_session()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
if is_sparse:
self.assertAllCloseAccordingToType([[0.0], [0.0]], v0_val)
self.assertAllCloseAccordingToType([[0.0], [0.0]], v1_val)
@@ -325,7 +325,7 @@
for _ in range(steps):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
return v0_val, v1_val
# When variables are initialized with Zero, FTRL-Proximal has two properties:
diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py
index ef50f63..1a52734 100644
--- a/tensorflow/python/training/gradient_descent.py
+++ b/tensorflow/python/training/gradient_descent.py
@@ -26,7 +26,7 @@
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.GradientDescentOptimizer")
+@tf_export(v1=["train.GradientDescentOptimizer"])
class GradientDescentOptimizer(optimizer.Optimizer):
"""Optimizer that implements the gradient descent algorithm.
"""
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index 31c2cc5..327f087 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -256,7 +256,7 @@
# writing of the `tf.Graph` object. However, many users
# write code this way, so we include this test to ensure
# that we can support it.
- self.assertEquals(string, sess.run(queue.dequeue()))
+ self.assertEquals(string, self.evaluate(queue.dequeue()))
coord.request_stop()
coord.join(threads)
@@ -348,14 +348,14 @@
# No randomness, so just see repeated copies of the input.
num_items = len(source_strings) * num_epochs
- output = [sess.run(slices) for _ in range(num_items)]
+ output = [self.evaluate(slices) for _ in range(num_items)]
out_strings, out_ints = zip(*output)
self.assertAllEqual(source_strings * num_epochs, out_strings)
self.assertAllEqual(source_ints * num_epochs, out_ints)
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(slices)
+ self.evaluate(slices)
for thread in threads:
thread.join()
@@ -383,7 +383,7 @@
for e in expected:
frequency[e] = 0
for _ in range(num_epochs):
- output = [sess.run(slices) for _ in range(len(source_strings))]
+ output = [self.evaluate(slices) for _ in range(len(source_strings))]
key = b",".join([s + compat.as_bytes(str(i)) for s, i in output])
self.assertIn(key, expected)
frequency[key] += 1
@@ -399,7 +399,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(slices)
+ self.evaluate(slices)
for thread in threads:
thread.join()
@@ -491,7 +491,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched_fetch)
+ self.evaluate(batched_fetch)
for thread in threads:
thread.join()
@@ -507,7 +507,7 @@
with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
- sess.run(batched)
+ self.evaluate(batched)
coord.request_stop()
for thread in threads:
thread.join()
@@ -518,7 +518,7 @@
with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord)
- sess.run(batched)
+ self.evaluate(batched)
coord.request_stop()
for thread in threads:
thread.join()
@@ -549,7 +549,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -584,7 +584,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -624,7 +624,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -681,7 +681,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -736,7 +736,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -834,7 +834,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -1051,7 +1051,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched_fetch)
+ self.evaluate(batched_fetch)
for thread in threads:
thread.join()
@@ -1148,7 +1148,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -1249,7 +1249,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -1347,7 +1347,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -1421,7 +1421,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -1597,7 +1597,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched_fetch)
+ self.evaluate(batched_fetch)
for thread in threads:
thread.join()
@@ -1659,7 +1659,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched_fetch)
+ self.evaluate(batched_fetch)
for thread in threads:
thread.join()
@@ -1706,7 +1706,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -1764,7 +1764,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -1824,7 +1824,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -2020,7 +2020,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched_fetch)
+ self.evaluate(batched_fetch)
for thread in threads:
thread.join()
@@ -2129,7 +2129,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
@@ -2210,7 +2210,7 @@
# Reached the limit.
with self.assertRaises(errors_impl.OutOfRangeError):
- sess.run(batched)
+ self.evaluate(batched)
for thread in threads:
thread.join()
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index 29b5465..c52e89d 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -100,7 +100,7 @@
return decayed_lr
-@tf_export(v1=["train.piecewise_constant"])
+@tf_export(v1=["train.piecewise_constant_decay", "train.piecewise_constant"])
def piecewise_constant(x, boundaries, values, name=None):
"""Piecewise constant from boundaries and interval values.
diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py
index 03a32f6..9c31c09 100644
--- a/tensorflow/python/training/learning_rate_decay_test.py
+++ b/tensorflow/python/training/learning_rate_decay_test.py
@@ -62,23 +62,22 @@
self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
def testVariables(self):
- with self.cached_session():
- step = variables.VariableV1(1)
- assign_1 = step.assign(1)
- assign_2 = step.assign(2)
- assign_100 = step.assign(100)
- decayed_lr = learning_rate_decay.exponential_decay(.1, step, 3, 0.96,
- staircase=True)
- variables.global_variables_initializer().run()
- # No change to learning rate
- assign_1.op.run()
- self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
- assign_2.op.run()
- self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
- # Decayed learning rate
- assign_100.op.run()
- expected = .1 * 0.96 ** (100 // 3)
- self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+ step = variables.VariableV1(1)
+ assign_1 = step.assign(1)
+ assign_2 = step.assign(2)
+ assign_100 = step.assign(100)
+ decayed_lr = learning_rate_decay.exponential_decay(
+ .1, step, 3, 0.96, staircase=True)
+ self.evaluate(variables.global_variables_initializer())
+ # No change to learning rate
+ self.evaluate(assign_1.op)
+ self.assertAllClose(self.evaluate(decayed_lr), .1, 1e-6)
+ self.evaluate(assign_2.op)
+ self.assertAllClose(self.evaluate(decayed_lr), .1, 1e-6)
+ # Decayed learning rate
+ self.evaluate(assign_100.op)
+ expected = .1 * 0.96**(100 // 3)
+ self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
@test_util.run_in_graph_and_eager_modes
def testPiecewiseConstant(self):
diff --git a/tensorflow/python/training/learning_rate_decay_v2.py b/tensorflow/python/training/learning_rate_decay_v2.py
index 9c5e144..eb69feb 100644
--- a/tensorflow/python/training/learning_rate_decay_v2.py
+++ b/tensorflow/python/training/learning_rate_decay_v2.py
@@ -117,7 +117,7 @@
decay_rate, staircase, name)
-@tf_export("train.piecewise_constant", v1=[])
+@tf_export("train.piecewise_constant_decay", v1=[])
def piecewise_constant(x, boundaries, values, name=None):
"""Piecewise constant from boundaries and interval values.
diff --git a/tensorflow/python/training/learning_rate_decay_v2_test.py b/tensorflow/python/training/learning_rate_decay_v2_test.py
index b2ac93f..354ddb2 100644
--- a/tensorflow/python/training/learning_rate_decay_v2_test.py
+++ b/tensorflow/python/training/learning_rate_decay_v2_test.py
@@ -62,23 +62,22 @@
self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
def testVariables(self):
- with self.cached_session():
- step = variables.Variable(1)
- assign_1 = step.assign(1)
- assign_2 = step.assign(2)
- assign_100 = step.assign(100)
- decayed_lr = learning_rate_decay_v2.exponential_decay(.1, step, 3, 0.96,
- staircase=True)
- variables.global_variables_initializer().run()
- # No change to learning rate
- assign_1.op.run()
- self.assertAllClose(decayed_lr().eval(), .1, 1e-6)
- assign_2.op.run()
- self.assertAllClose(decayed_lr().eval(), .1, 1e-6)
- # Decayed learning rate
- assign_100.op.run()
- expected = .1 * 0.96 ** (100 // 3)
- self.assertAllClose(decayed_lr().eval(), expected, 1e-6)
+ step = variables.Variable(1)
+ assign_1 = step.assign(1)
+ assign_2 = step.assign(2)
+ assign_100 = step.assign(100)
+ decayed_lr = learning_rate_decay_v2.exponential_decay(
+ .1, step, 3, 0.96, staircase=True)
+ self.evaluate(variables.global_variables_initializer())
+ # No change to learning rate
+ self.evaluate(assign_1.op)
+ self.assertAllClose(self.evaluate(decayed_lr()), .1, 1e-6)
+ self.evaluate(assign_2.op)
+ self.assertAllClose(self.evaluate(decayed_lr()), .1, 1e-6)
+ # Decayed learning rate
+ self.evaluate(assign_100.op)
+ expected = .1 * 0.96**(100 // 3)
+ self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
@test_util.run_in_graph_and_eager_modes
def testPiecewiseConstant(self):
diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py
index 4a280e7..f3bc83b 100644
--- a/tensorflow/python/training/momentum.py
+++ b/tensorflow/python/training/momentum.py
@@ -25,7 +25,7 @@
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.MomentumOptimizer")
+@tf_export(v1=["train.MomentumOptimizer"])
class MomentumOptimizer(optimizer.Optimizer):
"""Optimizer that implements the Momentum algorithm.
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 162fef9..c40bd2b 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -508,7 +508,7 @@
stop_grace_period_secs=stop_grace_period_secs)
-@tf_export('train.SessionCreator')
+@tf_export(v1=['train.SessionCreator'])
@six.add_metaclass(abc.ABCMeta)
class SessionCreator(object):
"""A factory for tf.Session."""
@@ -519,7 +519,7 @@
'create_session is not implemented for {}.'.format(self))
-@tf_export('train.ChiefSessionCreator')
+@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
"""Creates a tf.Session for a chief."""
@@ -571,7 +571,7 @@
init_fn=self._scaffold.init_fn)
-@tf_export('train.WorkerSessionCreator')
+@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
"""Creates a tf.Session for a worker."""
@@ -840,10 +840,18 @@
return self._coordinated_creator.tf_sess is None
def _tf_sess(self):
+ """Return underlying tf.Session object.
+
+ Warning: accessing the returned object in user code is likely to cause races
+ or "flaky tests".
+
+ Returns:
+ A tf.Session object.
+ """
return self._coordinated_creator.tf_sess
-@tf_export('train.MonitoredSession')
+@tf_export(v1=['train.MonitoredSession'])
class MonitoredSession(_MonitoredSession):
"""Session-like object that handles initialization, recovery and hooks.
@@ -926,7 +934,7 @@
stop_grace_period_secs=stop_grace_period_secs)
-@tf_export('train.SingularMonitoredSession')
+@tf_export(v1=['train.SingularMonitoredSession'])
class SingularMonitoredSession(_MonitoredSession):
"""Session-like object that handles initialization, restoring, and hooks.
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index ebe2f15..2ceb387 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -382,6 +382,16 @@
self.assertEqual(0, session.run(gstep))
+class MockExtended(object):
+
+ def __init__(self, between_graph, should_init, should_checkpoint,
+ should_save_summary):
+ self.experimental_between_graph = between_graph
+ self.experimental_should_init = should_init
+ self.should_checkpoint = should_checkpoint
+ self.should_save_summary = should_save_summary
+
+
class MockStrategy(object):
def __init__(self,
@@ -389,26 +399,8 @@
should_init=True,
should_checkpoint=None,
should_save_summary=None):
- self._between_graph = between_graph
- self._should_init = should_init
- self._should_checkpoint = should_checkpoint
- self._should_save_summary = should_save_summary
-
- @property
- def between_graph(self):
- return self._between_graph
-
- @property
- def should_init(self):
- return self._should_init
-
- @property
- def should_checkpoint(self):
- return self._should_checkpoint
-
- @property
- def should_save_summary(self):
- return self._should_save_summary
+ self.extended = MockExtended(between_graph, should_init, should_checkpoint,
+ should_save_summary)
class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 41e9dce..6ce5de6 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -278,7 +278,7 @@
self.evaluate(v0.initializer)
self.assertEqual([10.0], self.evaluate(v1_avg))
# running ema_op should add to v0 (in addition to updating v1_avg)
- sess.run(assign_to_v1)
+ self.evaluate(assign_to_v1)
self.evaluate(ema_op)
self.assertEqual(1, self.evaluate(v0))
self.assertEqual([17.5], self.evaluate(v1_avg))
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 6fca4ca..900afee 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -24,6 +24,7 @@
import six
+from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@@ -37,7 +38,6 @@
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import distribution_strategy_context as distribute_ctx
from tensorflow.python.training import slot_creator
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -214,7 +214,7 @@
raise NotImplementedError("Trying to optimize unsupported type ", v)
-@tf_export("train.Optimizer")
+@tf_export(v1=["train.Optimizer"])
class Optimizer(
# Optimizers inherit from CheckpointableBase rather than Checkpointable
# since they do most of their dependency management themselves (slot
@@ -660,7 +660,7 @@
replicas. If `global_step` was not None, that operation also
increments `global_step`
"""
- reduced_grads = distribution.batch_reduce(
+ reduced_grads = distribution.extended.batch_reduce_to(
ds_reduce_util.ReduceOp.SUM, grads_and_vars)
var_list = [v for _, v in grads_and_vars]
grads_and_vars = zip(reduced_grads, var_list)
@@ -695,21 +695,23 @@
update_ops = [
op
for grad, var in grads_and_vars
- for op in distribution.update(var, update, grad, grouped=False)
+ for op in distribution.extended.update(
+ var, update, args=(grad,), group=False)
]
def finish(self, update_ops):
return self._finish(update_ops, "update")
- non_slot_devices = distribution.non_slot_devices(var_list)
- finish_updates = distribution.update_non_slot(
- non_slot_devices, finish, self, update_ops, grouped=False)
+ non_slot_devices = distribution.extended.non_slot_devices(var_list)
+ finish_updates = distribution.extended.update_non_slot(
+ non_slot_devices, finish, args=(self, update_ops), group=False)
if global_step is None:
apply_updates = distribution.group(finish_updates, name=name)
else:
with ops.control_dependencies(finish_updates):
- apply_updates = distribution.update(
- global_step, state_ops.assign_add, 1, name=name)
+ apply_updates = distribution.extended.update(
+ global_step, state_ops.assign_add, args=(1,),
+ kwargs={"name": name})
if not context.executing_eagerly():
if isinstance(apply_updates, ops.Tensor):
diff --git a/tensorflow/python/training/proximal_adagrad.py b/tensorflow/python/training/proximal_adagrad.py
index 9bd677b..2ea628a 100644
--- a/tensorflow/python/training/proximal_adagrad.py
+++ b/tensorflow/python/training/proximal_adagrad.py
@@ -26,7 +26,7 @@
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.ProximalAdagradOptimizer")
+@tf_export(v1=["train.ProximalAdagradOptimizer"])
class ProximalAdagradOptimizer(optimizer.Optimizer):
# pylint: disable=line-too-long
"""Optimizer that implements the Proximal Adagrad algorithm.
diff --git a/tensorflow/python/training/proximal_adagrad_test.py b/tensorflow/python/training/proximal_adagrad_test.py
index 272f901..9d46a66 100644
--- a/tensorflow/python/training/proximal_adagrad_test.py
+++ b/tensorflow/python/training/proximal_adagrad_test.py
@@ -48,7 +48,7 @@
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose([0.0, 0.0], v0_val)
self.assertAllClose([0.0, 0.0], v1_val)
@@ -56,7 +56,7 @@
for _ in range(3):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose(np.array([-2.60260963, -4.29698515]), v0_val)
self.assertAllClose(np.array([-0.28432083, -0.56694895]), v1_val)
opt_vars = opt.variables()
@@ -85,14 +85,14 @@
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose([1.0, 2.0], v0_val)
self.assertAllClose([4.0, 3.0], v1_val)
# Run 3 steps Proximal Adagrad.
for _ in range(3):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose(np.array([-1.60261, -2.296985]), v0_val)
self.assertAllClose(np.array([3.715679, 2.433051]), v1_val)
@@ -129,14 +129,14 @@
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose([1.0, 2.0], v0_val)
self.assertAllClose([4.0, 3.0], v1_val)
# Run 10 steps Proximal Adagrad
for _ in range(10):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose(np.array([-6.663634, -9.190331]), v0_val)
self.assertAllClose(np.array([2.959304, 1.029232]), v1_val)
@@ -155,7 +155,7 @@
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose([1.0, 2.0], v0_val)
self.assertAllClose([4.0, 3.0], v1_val)
@@ -163,7 +163,7 @@
for _ in range(10):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose(np.array([-0.0495, -0.0995]), v0_val)
self.assertAllClose(np.array([-0.0045, -0.0095]), v1_val)
@@ -191,7 +191,7 @@
variables.global_variables_initializer().run()
sess = ops.get_default_session()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
if is_sparse:
self.assertAllClose([[1.0], [2.0]], v0_val)
self.assertAllClose([[3.0], [4.0]], v1_val)
@@ -203,7 +203,7 @@
for _ in range(steps):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
return v0_val, v1_val
def testEquivAdagradwithoutRegularization(self):
diff --git a/tensorflow/python/training/proximal_gradient_descent_test.py b/tensorflow/python/training/proximal_gradient_descent_test.py
index a9355f4..8797b30 100644
--- a/tensorflow/python/training/proximal_gradient_descent_test.py
+++ b/tensorflow/python/training/proximal_gradient_descent_test.py
@@ -50,7 +50,7 @@
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose([0.0, 0.0], v0_val)
self.assertAllClose([0.0, 0.0], v1_val)
@@ -58,7 +58,7 @@
for _ in range(3):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose(np.array([-0.9, -1.8]), v0_val)
self.assertAllClose(np.array([-0.09, -0.18]), v1_val)
@@ -80,7 +80,7 @@
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose([1.0, 2.0], v0_val)
self.assertAllClose([4.0, 3.0], v1_val)
@@ -88,7 +88,7 @@
for _ in range(3):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose(np.array([0.1, 0.2]), v0_val)
self.assertAllClose(np.array([3.91, 2.82]), v1_val)
@@ -123,7 +123,7 @@
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose([1.0, 2.0], v0_val)
self.assertAllClose([4.0, 3.0], v1_val)
@@ -131,7 +131,7 @@
for _ in range(10):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
self.assertAllClose(np.array([-0.0495, -0.0995]), v0_val)
self.assertAllClose(np.array([-0.0045, -0.0095]), v1_val)
@@ -159,7 +159,7 @@
variables.global_variables_initializer().run()
sess = ops.get_default_session()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
if is_sparse:
self.assertAllClose([[1.0], [2.0]], v0_val)
self.assertAllClose([[3.0], [4.0]], v1_val)
@@ -171,7 +171,7 @@
for _ in range(steps):
update.run()
- v0_val, v1_val = sess.run([var0, var1])
+ v0_val, v1_val = self.evaluate([var0, var1])
return v0_val, v1_val
def testEquivSparseGradientDescentwithoutRegularization(self):
diff --git a/tensorflow/python/training/quantize_training_test.py b/tensorflow/python/training/quantize_training_test.py
index 6edbf76..07fd488 100644
--- a/tensorflow/python/training/quantize_training_test.py
+++ b/tensorflow/python/training/quantize_training_test.py
@@ -73,11 +73,11 @@
_ = importer.import_graph_def(result, name='')
# Initialize the variable.
- sess.run(g.get_operation_by_name(init_op.name))
+ self.evaluate(g.get_operation_by_name(init_op.name))
# Run the graph for one step to assign values to the quantization min/max
# variables.
- sess.run(g.get_tensor_by_name(c.name))
+ self.evaluate(g.get_tensor_by_name(c.name))
saver.save(sess, save_path)
diff --git a/tensorflow/python/training/rmsprop.py b/tensorflow/python/training/rmsprop.py
index f38c986..fb53b58 100644
--- a/tensorflow/python/training/rmsprop.py
+++ b/tensorflow/python/training/rmsprop.py
@@ -50,7 +50,7 @@
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.RMSPropOptimizer")
+@tf_export(v1=["train.RMSPropOptimizer"])
class RMSPropOptimizer(optimizer.Optimizer):
"""Optimizer that implements the RMSProp algorithm.
diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py
index 9ec315f..a9b8954 100644
--- a/tensorflow/python/training/rmsprop_test.py
+++ b/tensorflow/python/training/rmsprop_test.py
@@ -28,6 +28,7 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -92,7 +93,7 @@
# TODO(yori): Use ParameterizedTest when available
for (dtype, learning_rate, decay, momentum,
epsilon, centered, use_resource) in _TESTPARAMS:
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
# Initialize variables for numpy implementation.
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype)
@@ -115,7 +116,7 @@
centered=centered)
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
mg0 = opt.get_slot(var0, "mg")
self.assertEqual(mg0 is not None, centered)
@@ -143,7 +144,7 @@
# Run 4 steps of RMSProp
for _ in range(1, 5):
- update.run()
+ self.evaluate(update)
var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate,
@@ -176,11 +177,11 @@
momentum=0.0,
epsilon=0.0,
centered=False).minimize(loss)
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
# Run 1 step of sgd
- sgd_op.run()
+ self.evaluate(sgd_op)
# Validate updated params
self.assertAllCloseAccordingToType([[0., 1.]],
self.evaluate(var0),
@@ -199,11 +200,11 @@
momentum=0.0,
epsilon=1.0,
centered=True).minimize(loss)
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
self.assertAllCloseAccordingToType([[1.0, 2.0]], self.evaluate(var0))
# Run 1 step of sgd
- sgd_op.run()
+ self.evaluate(sgd_op)
# Validate updated params
self.assertAllCloseAccordingToType([[-111, -138]],
self.evaluate(var0),
@@ -213,7 +214,7 @@
# TODO(yori): Use ParameterizedTest when available
for (dtype, learning_rate, decay,
momentum, epsilon, centered, _) in _TESTPARAMS:
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
# Initialize variables for numpy implementation.
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)
@@ -237,7 +238,7 @@
epsilon=epsilon,
centered=centered)
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
mg0 = opt.get_slot(var0, "mg")
self.assertEqual(mg0 is not None, centered)
@@ -265,7 +266,7 @@
# Run 4 steps of RMSProp
for _ in range(1, 5):
- update.run()
+ self.evaluate(update)
var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np,
@@ -287,7 +288,7 @@
def testWithoutMomentum(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -295,7 +296,7 @@
opt = rmsprop.RMSPropOptimizer(
learning_rate=2.0, decay=0.9, momentum=0.0, epsilon=1.0)
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
rms0 = opt.get_slot(var0, "rms")
self.assertTrue(rms0 is not None)
@@ -311,7 +312,7 @@
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Step 1: the rms accumulators where 1. So we should see a normal
# update: v -= grad * learning_rate
- update.run()
+ self.evaluate(update)
# Check the root mean square accumulators.
self.assertAllCloseAccordingToType(
np.array([0.901, 0.901]), self.evaluate(rms0))
@@ -329,7 +330,7 @@
4.0 - (0.01 * 2.0 / math.sqrt(0.90001 + 1.0))
]), self.evaluate(var1))
# Step 2: the root mean square accumulators contain the previous update.
- update.run()
+ self.evaluate(update)
# Check the rms accumulators.
self.assertAllCloseAccordingToType(
np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]),
@@ -355,7 +356,7 @@
def testWithMomentum(self):
for dtype in [dtypes.half, dtypes.float32]:
- with self.cached_session(use_gpu=True):
+ with test_util.use_gpu():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -364,7 +365,7 @@
opt = rmsprop.RMSPropOptimizer(
learning_rate=2.0, decay=0.9, momentum=0.5, epsilon=1e-5)
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
+ self.evaluate(variables.global_variables_initializer())
rms0 = opt.get_slot(var0, "rms")
self.assertTrue(rms0 is not None)
@@ -380,7 +381,7 @@
self.assertAllClose([3.0, 4.0], self.evaluate(var1))
# Step 1: rms = 1, mom = 0. So we should see a normal
# update: v -= grad * learning_rate
- update.run()
+ self.evaluate(update)
# Check the root mean square accumulators.
self.assertAllCloseAccordingToType(
np.array([0.901, 0.901]), self.evaluate(rms0))
@@ -409,7 +410,7 @@
]), self.evaluate(var1))
# Step 2: the root mean square accumulators contain the previous update.
- update.run()
+ self.evaluate(update)
# Check the rms accumulators.
self.assertAllCloseAccordingToType(
np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]),
diff --git a/tensorflow/python/training/saver_large_partitioned_variable_test.py b/tensorflow/python/training/saver_large_partitioned_variable_test.py
index 1a44511..8445883 100644
--- a/tensorflow/python/training/saver_large_partitioned_variable_test.py
+++ b/tensorflow/python/training/saver_large_partitioned_variable_test.py
@@ -25,6 +25,7 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
@@ -44,8 +45,12 @@
# split into smaller sized variables.
init = lambda shape, dtype, partition_info: constant_op.constant(
True, dtype, shape)
- partitioned_var = partitioned_variables.create_partitioned_variables(
- [1 << 31], [4], init, dtype=dtypes.bool, name=var_name)
+ partitioned_var = list(variable_scope.get_variable(
+ var_name,
+ shape=[1 << 31],
+ partitioner=partitioned_variables.fixed_size_partitioner(4),
+ initializer=init,
+ dtype=dtypes.bool))
variables.global_variables_initializer().run()
save = saver.Saver(partitioned_var)
val = save.save(sess, save_path)
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 7bc0a17..6b2177b 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -407,7 +407,7 @@
with self.assertRaisesWithPredicateMatch(
errors_impl.OpError, lambda e: "uninitialized value v" in e.message):
- sess.run(v)
+ self.evaluate(v)
# Restore the saved values in the parameter nodes.
save.restore(sess, save_path)
@@ -497,10 +497,10 @@
with self.assertRaisesWithPredicateMatch(
errors_impl.OpError, lambda e: "uninitialized value v0" in e.message):
- sess.run(v0)
+ self.evaluate(v0)
with self.assertRaisesWithPredicateMatch(
errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):
- sess.run(v1)
+ self.evaluate(v1)
self.assertEqual(0, len(v2.keys().eval()))
self.assertEqual(0, len(v2.values().eval()))
@@ -998,19 +998,12 @@
call_saver_with_dict = False # updated by test loop below
- def _save(slices=None, partitioner=None):
+ def _save(partitioner=None):
with self.session(graph=ops_lib.Graph()) as sess:
# Calls .eval() to return the ndarray that makes up the full variable.
rnd = random_ops.random_uniform(var_full_shape).eval()
- if slices:
- assert not partitioner
- # TODO(apassos): make create_partitioned_variables take use_resource
- # option to make this test passable without creating a named
- # variable_scope.
- vs = partitioned_variables.create_partitioned_variables(
- var_full_shape, slices, rnd, name=var_name)
- elif partitioner:
+ if partitioner:
vs = [
variable_scope.get_variable(
var_name,
@@ -1027,7 +1020,7 @@
variables.global_variables_initializer().run()
if call_saver_with_dict:
- saver = saver_module.Saver({var_name: (vs if slices else vs[0])})
+ saver = saver_module.Saver({var_name: vs[0]})
else:
saver = saver_module.Saver(vs)
actual_path = saver.save(sess, saved_path)
@@ -1035,16 +1028,9 @@
return rnd
- def _restore(slices=None, partitioner=None):
+ def _restore(partitioner=None):
with self.session(graph=ops_lib.Graph()) as sess:
- if slices:
- assert not partitioner
- new_vs = partitioned_variables.create_partitioned_variables(
- var_full_shape,
- slices,
- array_ops.zeros(var_full_shape), # != original contents.
- name=var_name)
- elif partitioner:
+ if partitioner:
new_vs = [
variable_scope.get_variable(
var_name,
@@ -1063,7 +1049,7 @@
variables.global_variables_initializer().run()
if call_saver_with_dict:
saver = saver_module.Saver({
- var_name: (new_vs if slices else new_vs[0])
+ var_name: new_vs[0]
})
else:
saver = saver_module.Saver(new_vs)
@@ -1071,11 +1057,7 @@
if partitioner:
return new_vs[0].as_tensor().eval()
- elif slices and slices[0] != 1:
- return array_ops.concat(new_vs, 0).eval()
- elif slices and slices[1] != 1:
- return array_ops.concat(new_vs, 1).eval()
- else: # Non-sliced.
+ else:
return new_vs[0].eval()
for call_saver_with_dict in {False, True}:
@@ -1086,27 +1068,23 @@
restored_full = _restore()
self.assertAllEqual(saved_full, restored_full)
- # Saves 10 horizontal parts of a partitioned variable.
- # Restores into a full variable, non-sliced.
- saved_full = _save(slices=[10, 1])
- restored_full = _restore()
- self.assertAllEqual(saved_full, restored_full)
-
- # Restores into a different number/orientation of slices.
- restored_full = _restore(slices=[2, 1]) # 2 horizon parts.
- self.assertAllEqual(saved_full, restored_full)
- restored_full = _restore(slices=[1, 3]) # 3 vertical parts.
- self.assertAllEqual(saved_full, restored_full)
-
- # Restores into a PartitionedVariable
+ # Restores into the same number of partitions.
restored_full = _restore(
partitioner=partitioned_variables.fixed_size_partitioner(
num_shards=2))
self.assertAllEqual(saved_full, restored_full)
- # Now, saves a full variable and restores in slices.
+ # Restores into a different number of partitions.
+ restored_full = _restore(
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=3))
+ self.assertAllEqual(saved_full, restored_full)
+
+ # Now, saves a full variable and restores PartitionedVariable.
saved_full = _save()
- restored_full = _restore(slices=[1, 3])
+ restored_full = _restore(
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=3))
self.assertAllEqual(saved_full, restored_full)
def testPartitionedVariable(self):
@@ -1769,7 +1747,7 @@
self.assertEqual([], v1.get_shape())
with self.assertRaisesWithPredicateMatch(
errors_impl.OpError, lambda e: "uninitialized value v1" in e.message):
- sess.run(v1)
+ self.evaluate(v1)
# Retrieves saver1. Verifies that new_saver1 can restore v1.
new_saver1 = savers[1]
new_saver1.restore(sess, saver1_ckpt)
@@ -1951,7 +1929,7 @@
# Initializes all the variables.
self.evaluate(init_all_op)
# Runs to logit.
- sess.run(logits)
+ self.evaluate(logits)
# Creates a saver.
saver0 = saver_module.Saver()
saver0.save(sess, saver0_ckpt)
@@ -2038,7 +2016,7 @@
# Generate a MetaGraphDef containing the while loop.
with session.Session() as sess:
self.evaluate(init_op)
- sess.run(output)
+ self.evaluate(output)
saver = saver_module.Saver()
saver.save(sess, saver_ckpt)
saver.export_meta_graph(filename)
@@ -2609,10 +2587,10 @@
saver = saver_module.Saver(var_list=var_list, max_to_keep=1)
saver.restore(sess, os.path.join(test_dir, ckpt_filename))
# Verify that we have restored weights1 and biases1.
- sess.run([weights1, biases1])
+ self.evaluate([weights1, biases1])
# Initialize the rest of the variables and run logits.
self.evaluate(init_rest_op)
- sess.run(logits)
+ self.evaluate(logits)
# Verifies that we can save the subgraph under "hidden1" and restore it
# into "new_hidden1" in the new graph.
diff --git a/tensorflow/python/training/server_lib_same_variables_clear_container_test.py b/tensorflow/python/training/server_lib_same_variables_clear_container_test.py
index 11e6f28..3a5eb71 100644
--- a/tensorflow/python/training/server_lib_same_variables_clear_container_test.py
+++ b/tensorflow/python/training/server_lib_same_variables_clear_container_test.py
@@ -60,9 +60,9 @@
session.Session.reset(server0.target, ["local0"])
sess = session.Session(server0.target)
with self.assertRaises(errors_impl.FailedPreconditionError):
- sess.run(v0)
+ self.evaluate(v0)
# Reinitializes v0 for the following test.
- sess.run(v0.initializer)
+ self.evaluate(v0.initializer)
# Verifies that v1 is still valid.
self.assertAllEqual(2.0, sess_1.run(v1))
@@ -71,10 +71,10 @@
session.Session.reset(server1.target, ["local1"])
sess = session.Session(server1.target)
with self.assertRaises(errors_impl.FailedPreconditionError):
- sess.run(v1)
+ self.evaluate(v1)
# Verifies that v0 is still valid.
sess = session.Session(server0.target)
- self.assertAllEqual(1.0, sess.run(v0))
+ self.assertAllEqual(1.0, self.evaluate(v0))
if __name__ == "__main__":
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index cd313c2..1465863 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -46,7 +46,7 @@
return "<no name for %s>" % type(obj)
-@tf_export("train.SessionManager")
+@tf_export(v1=["train.SessionManager"])
class SessionManager(object):
"""Training helper that restores from checkpoint and creates session.
diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py
index a5e626d..de60dd4 100644
--- a/tensorflow/python/training/supervisor.py
+++ b/tensorflow/python/training/supervisor.py
@@ -40,7 +40,7 @@
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.Supervisor")
+@tf_export(v1=["train.Supervisor"])
class Supervisor(object):
"""A training helper that checkpoints models and computes summaries.
diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py
index 78dbb46..19dc04e 100644
--- a/tensorflow/python/training/warm_starting_util.py
+++ b/tensorflow/python/training/warm_starting_util.py
@@ -32,7 +32,7 @@
from tensorflow.python.util.tf_export import tf_export
-@tf_export("train.VocabInfo")
+@tf_export(v1=["train.VocabInfo"])
class VocabInfo(
collections.namedtuple("VocabInfo", [
"new_vocab",
@@ -248,7 +248,7 @@
prev_tensor_name = _infer_var_name(var)
# TODO(eddz): Fix functionality for rank-1 Variables (like FC biases).
- total_v_first_axis = sum([v.get_shape().as_list()[0] for v in var])
+ total_v_first_axis = sum(v.get_shape().as_list()[0] for v in var)
for v in var:
v_shape = v.get_shape().as_list()
slice_info = v._get_save_slice_info()
@@ -333,12 +333,12 @@
ops.GraphKeys.TRAINABLE_VARIABLES,
scope=vars_to_warm_start)
elif isinstance(vars_to_warm_start, list):
- if all([isinstance(v, str) for v in vars_to_warm_start]):
+ if all(isinstance(v, str) for v in vars_to_warm_start):
list_of_vars = []
for v in vars_to_warm_start:
list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
scope=v)
- elif all([checkpoint_utils._is_variable(v) for v in vars_to_warm_start]): # pylint: disable=protected-access
+ elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start): # pylint: disable=protected-access
list_of_vars = vars_to_warm_start
else:
raise ValueError("If `vars_to_warm_start` is a list, it must be all "
diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py
index f1e719e..fa1f370 100644
--- a/tensorflow/python/training/warm_starting_util_test.py
+++ b/tensorflow/python/training/warm_starting_util_test.py
@@ -70,7 +70,7 @@
if partitioner:
self.assertTrue(isinstance(var, variables.PartitionedVariable))
var = var._get_variable_list()
- return var, sess.run(var)
+ return var, self.evaluate(var)
def _create_prev_run_vars(self,
var_names,
@@ -86,7 +86,7 @@
shape=shape,
initializer=initializer))
self._write_checkpoint(sess)
- return [sess.run(var) for var in all_vars]
+ return [self.evaluate(var) for var in all_vars]
def _create_dummy_inputs(self):
return {
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index ad91542..4874d09 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -662,8 +662,13 @@
}
bool CUDAExecutor::HostCallback(Stream *stream,
- std::function<void()> callback) {
- auto callback_ptr = new std::function<void()>(callback);
+ std::function<port::Status()> callback) {
+ auto callback_ptr = new std::function<void()>([callback]() {
+ port::Status s = callback();
+ if (!s.ok()) {
+ LOG(WARNING) << "Host callback failed: " << s;
+ }
+ });
return CUDADriver::AddStreamCallback(context_, AsCUDAStreamValue(stream),
InternalHostCallback, callback_ptr);
}
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
index 90bf1c0..ae8e4ab 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
@@ -148,7 +148,8 @@
const DeviceMemoryBase &gpu_src,
uint64 size) override;
- bool HostCallback(Stream *stream, std::function<void()> callback) override;
+ bool HostCallback(Stream *stream,
+ std::function<port::Status()> callback) override;
bool AllocateStream(Stream *stream) override;
diff --git a/tensorflow/stream_executor/device_description.cc b/tensorflow/stream_executor/device_description.cc
index 4120e23..0b991b7 100644
--- a/tensorflow/stream_executor/device_description.cc
+++ b/tensorflow/stream_executor/device_description.cc
@@ -140,21 +140,11 @@
uint64 element_count, uint64 *threads_per_block,
uint64 *block_count) {
*threads_per_block = device_description.threads_per_block_limit();
- *block_count = DivideCeil(element_count, *threads_per_block);
+ *block_count = port::MathUtil::CeilOfRatio(element_count, *threads_per_block);
if (*block_count == 1) {
CHECK_LE(element_count, *threads_per_block);
*threads_per_block = element_count;
}
}
-// Round value up to a multiple of n.
-static uint64 RoundUp(uint64 value, uint64 n) {
- return port::MathUtil::CeilOfRatio(value, n) * n;
-}
-
-// Round value down to a multiple of n.
-static uint64 RoundDown(uint64 value, uint64 n) {
- return port::MathUtil::FloorOfRatio(value, n) * n;
-}
-
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc
index 8adf739..1396a83 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.cc
+++ b/tensorflow/stream_executor/host/host_gpu_executor.cc
@@ -148,8 +148,13 @@
}
bool HostExecutor::HostCallback(Stream *stream,
- std::function<void()> callback) {
- AsHostStream(stream)->EnqueueTask(callback);
+ std::function<port::Status()> callback) {
+ AsHostStream(stream)->EnqueueTask([callback]() {
+ port::Status s = callback();
+ if (!s.ok()) {
+ LOG(WARNING) << "Host callback failed: " << s;
+ }
+ });
return true;
}
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h
index 7ba1f18..56e3c2a 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.h
+++ b/tensorflow/stream_executor/host/host_gpu_executor.h
@@ -103,7 +103,8 @@
const DeviceMemoryBase &gpu_src,
uint64 size) override;
- bool HostCallback(Stream *stream, std::function<void()> callback) override;
+ bool HostCallback(Stream *stream,
+ std::function<port::Status()> callback) override;
port::Status AllocateEvent(Event *event) override {
return port::Status(port::error::UNIMPLEMENTED, "");
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index e1629b5..0fc90cf 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -2034,8 +2034,19 @@
internal::StreamInterface *implementation() { return implementation_.get(); }
// Entrains onto the stream a callback to the host (from the device).
+ // Behaves as ThenDoHostCallbackWithStatus below, but the callback should
+ // never fail or its failure is inconsequential.
+ //
+ // This is kept for backward compatibility. Future code should use
+ // ThenDoHostCallbackWithStatus and explicitly return a success status.
+ // TODO(b/112125301): Eventually remove this method.
+ Stream &ThenDoHostCallback(std::function<void()> callback);
+
+ // Entrains onto the stream a callback to the host (from the device).
// Host callbacks block/occupy the stream just as device functions
// (execute one at a time, block later stream operations).
+ // Whether the callback return status affects the result of BlockHostUntilDone
+ // is platform-dependent.
//
// Behavior is undefined when synchronizing using OpenCL user events.
// Behavior is undefined if host callbacks call device routines or insert
@@ -2043,11 +2054,6 @@
//
// On certain platforms, ThenDoHostCallback is expected to have significant
// negative effects on performance.
- Stream &ThenDoHostCallback(std::function<void()> callback);
-
- // Entrains onto the stream a callback to the host (from the device).
- // Behaves as ThenDoHostCallback above, but returns a Status instead of void.
- // This overload should be preferred if the callback could fail.
Stream &ThenDoHostCallbackWithStatus(std::function<port::Status()> callback);
// Returns the StreamExecutor (parent object) associated with this stream.
diff --git a/tensorflow/stream_executor/stream_executor_internal.cc b/tensorflow/stream_executor/stream_executor_internal.cc
index 7df6a36..341c6ed 100644
--- a/tensorflow/stream_executor/stream_executor_internal.cc
+++ b/tensorflow/stream_executor/stream_executor_internal.cc
@@ -36,16 +36,15 @@
StreamExecutorFactory MakeHostExecutorImplementation;
-// TODO(b/112125301): Consolodate this down to one implementation of
-// HostCallback, taking a callback that returns a Status.
-bool StreamExecutorInterface::HostCallback(
- Stream* stream, std::function<port::Status()> callback) {
- return HostCallback(stream, [callback]() {
- port::Status s = callback();
- if (!s.ok()) {
- LOG(WARNING) << "HostCallback failed: " << s;
- }
- });
+// The default implementation just calls the other HostCallback method.
+// It should make all existing code that uses a void() callback still work.
+bool StreamExecutorInterface::HostCallback(Stream* stream,
+ std::function<void()> callback) {
+ return HostCallback(
+ stream, std::function<port::Status()>([callback]() -> port::Status {
+ callback();
+ return port::Status::OK();
+ }));
}
} // namespace internal
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 32f75fd..0c2c33c 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -237,9 +237,9 @@
virtual bool MemcpyDeviceToDevice(Stream *stream, DeviceMemoryBase *gpu_dst,
const DeviceMemoryBase &gpu_src,
uint64 size) = 0;
- virtual bool HostCallback(Stream *stream, std::function<void()> callback) = 0;
+ virtual bool HostCallback(Stream *stream, std::function<void()> callback);
virtual bool HostCallback(Stream *stream,
- std::function<port::Status()> callback);
+ std::function<port::Status()> callback) = 0;
virtual port::Status AllocateEvent(Event *event) = 0;
virtual port::Status DeallocateEvent(Event *event) = 0;
virtual port::Status RecordEvent(Stream *stream, Event *event) = 0;
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 2d67d1f..4bc6844 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -203,8 +203,12 @@
"//conditions:default": [],
})
-def if_not_tx2_llvm_or_windows_cuda(a):
- return if_not_windows_cuda(a)
+def if_nccl(a):
+ return select({
+ "//tensorflow:no_nccl_support": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": a,
+ })
def get_win_copts(is_external = False):
WINDOWS_COPTS = [
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
index 0a16d6a..e37d299 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-gradient-tape.pbtxt
@@ -7,10 +7,18 @@
argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], "
}
member_method {
+ name: "batch_jacobian"
+ argspec: "args=[\'self\', \'target\', \'source\', \'unconnected_gradients\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'True\'], "
+ }
+ member_method {
name: "gradient"
argspec: "args=[\'self\', \'target\', \'sources\', \'output_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'UnconnectedGradients.NONE\'], "
}
member_method {
+ name: "jacobian"
+ argspec: "args=[\'self\', \'target\', \'sources\', \'unconnected_gradients\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'True\'], "
+ }
+ member_method {
name: "reset"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt
index 9d032d4..024b451 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.data.Options"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Options\'>"
+ is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_autotune"
@@ -54,6 +55,10 @@
name: "experimental_stats"
mtype: "<type \'property\'>"
}
+ member {
+ name: "experimental_threading"
+ mtype: "<type \'property\'>"
+ }
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-stats-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-stats-options.pbtxt
index f423eed..892f8c1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-stats-options.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-stats-options.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.data.experimental.StatsOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_options.StatsOptions\'>"
+ is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "aggregator"
@@ -20,6 +21,6 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'aggregator\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-threading-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-threading-options.pbtxt
new file mode 100644
index 0000000..5b5ebf1
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-threading-options.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.data.experimental.ThreadingOptions"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.threading_options.ThreadingOptions\'>"
+ is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "max_intra_op_parallelism"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "private_threadpool_size"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt
index 244b245..7bc3faa 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt
@@ -40,6 +40,10 @@
name: "TFRecordWriter"
mtype: "<type \'type\'>"
}
+ member {
+ name: "ThreadingOptions"
+ mtype: "<type \'type\'>"
+ }
member_method {
name: "Counter"
argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-context.pbtxt
index c39ac5a..583cbc6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-context.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-context.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.distribute.InputContext"
tf_class {
- is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>"
+ is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputContext\'>"
is_instance: "<type \'object\'>"
member {
name: "input_pipeline_id"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt
index 3eda6c6..df707e8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.distribute.ReplicaContext"
tf_class {
- is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>"
+ is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContext\'>"
is_instance: "<type \'object\'>"
member {
name: "devices"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt
index 3b502b5..77706e5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.distribute.StrategyExtended"
tf_class {
- is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>"
+ is_instance: "<class \'tensorflow.python.distribute.distribute_lib.DistributionStrategyExtended\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_between_graph"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt
index 4fe035b..0fd9a3b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.distribute.Strategy"
tf_class {
- is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>"
+ is_instance: "<class \'tensorflow.python.distribute.distribute_lib.DistributionStrategy\'>"
is_instance: "<type \'object\'>"
member {
name: "between_graph"
@@ -123,6 +123,10 @@
argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
+ name: "update_config_proto"
+ argspec: "args=[\'self\', \'config_proto\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "update_non_slot"
argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-classifier.pbtxt
index af16595..2257425 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-classifier.pbtxt
@@ -34,8 +34,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-estimator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-estimator.pbtxt
index d218773..38b27f7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-estimator.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-regressor.pbtxt
index e579425..5c51767 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-baseline-regressor.pbtxt
@@ -34,8 +34,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 970abd8..e138ce9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -34,6 +34,10 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "experimental_feature_importances"
argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
@@ -43,7 +47,7 @@
}
member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index b5bbad9..eae0a29 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -34,6 +34,10 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "experimental_feature_importances"
argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
@@ -43,7 +47,7 @@
}
member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-classifier.pbtxt
index 77e60d4..a142ca3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-classifier.pbtxt
@@ -34,8 +34,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-estimator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-estimator.pbtxt
index 85ff5a4..09e0d38 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-estimator.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
index 07aefed..85a2082 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
@@ -34,8 +34,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-estimator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-estimator.pbtxt
index ac13dad..e311f96 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-estimator.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
index 852e8d2..e05c7ce 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
@@ -34,8 +34,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-regressor.pbtxt
index 2779cbe..fc3b1d9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-d-n-n-regressor.pbtxt
@@ -34,8 +34,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-estimator.pbtxt
index eee5746..bff6c86 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-estimator.pbtxt
@@ -32,8 +32,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-classifier.pbtxt
index 6569e92..d213551 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-classifier.pbtxt
@@ -34,8 +34,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-estimator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-estimator.pbtxt
index 023edec..2148374 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-estimator.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-regressor.pbtxt
index d74bf4f..004dfcc 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-linear-regressor.pbtxt
@@ -34,8 +34,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.experimental.-in-memory-evaluator-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.experimental.-in-memory-evaluator-hook.pbtxt
new file mode 100644
index 0000000..aba1202
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.experimental.-in-memory-evaluator-hook.pbtxt
@@ -0,0 +1,30 @@
+path: "tensorflow.estimator.experimental.InMemoryEvaluatorHook"
+tf_class {
+ is_instance: "<class \'tensorflow_estimator.python.estimator.hooks.InMemoryEvaluatorHook\'>"
+ is_instance: "<class \'tensorflow.python.training.session_run_hook.SessionRunHook\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'estimator\', \'input_fn\', \'steps\', \'hooks\', \'name\', \'every_n_iter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'100\'], "
+ }
+ member_method {
+ name: "after_create_session"
+ argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "after_run"
+ argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "before_run"
+ argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "begin"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "end"
+ argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.experimental.pbtxt
index cabca3e..2a9a034 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.experimental.pbtxt
@@ -1,10 +1,18 @@
path: "tensorflow.estimator.experimental"
tf_module {
member {
+ name: "InMemoryEvaluatorHook"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "LinearSDCA"
mtype: "<type \'type\'>"
}
member_method {
+ name: "build_raw_supervised_input_receiver_fn"
+ argspec: "args=[\'features\', \'labels\', \'default_batch_size\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "call_logit_fn"
argspec: "args=[\'logit_fn\', \'features\', \'mode\', \'params\', \'config\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.gfile.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.gfile.pbtxt
new file mode 100644
index 0000000..e5aba7e
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.gfile.pbtxt
@@ -0,0 +1,51 @@
+path: "tensorflow.io.gfile"
+tf_module {
+ member_method {
+ name: "copy"
+ argspec: "args=[\'src\', \'dst\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "exists"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "glob"
+ argspec: "args=[\'pattern\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "isdir"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "listdir"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "makedirs"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "mkdir"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "remove"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "rename"
+ argspec: "args=[\'src\', \'dst\', \'overwrite\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "rmtree"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "stat"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "walk"
+ argspec: "args=[\'top\', \'topdown\', \'onerror\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt
index 64b63ed..b760ec3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt
@@ -44,11 +44,23 @@
name: "VarLenFeature"
mtype: "<type \'type\'>"
}
+ member {
+ name: "gfile"
+ mtype: "<type \'module\'>"
+ }
+ member_method {
+ name: "decode_and_crop_jpeg"
+ argspec: "args=[\'contents\', \'crop_window\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], "
+ }
member_method {
name: "decode_base64"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "decode_bmp"
+ argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
+ }
+ member_method {
name: "decode_compressed"
argspec: "args=[\'bytes\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
@@ -57,10 +69,26 @@
argspec: "args=[\'records\', \'record_defaults\', \'field_delim\', \'use_quote_delim\', \'name\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\',\', \'True\', \'None\', \'\', \'None\'], "
}
member_method {
+ name: "decode_gif"
+ argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "decode_image"
+ argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'uint8\'>\", \'None\'], "
+ }
+ member_method {
+ name: "decode_jpeg"
+ argspec: "args=[\'contents\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], "
+ }
+ member_method {
name: "decode_json_example"
argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "decode_png"
+ argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'uint8\'>\", \'None\'], "
+ }
+ member_method {
name: "decode_raw"
argspec: "args=[\'bytes\', \'out_type\', \'little_endian\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
@@ -73,6 +101,18 @@
argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
+ name: "encode_jpeg"
+ argspec: "args=[\'image\', \'format\', \'quality\', \'progressive\', \'optimize_size\', \'chroma_downsampling\', \'density_unit\', \'x_density\', \'y_density\', \'xmp_metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'95\', \'False\', \'False\', \'True\', \'in\', \'300\', \'300\', \'\', \'None\'], "
+ }
+ member_method {
+ name: "extract_jpeg_shape"
+ argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
+ }
+ member_method {
+ name: "is_jpeg"
+ argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "match_filenames_once"
argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt
index 8200345..b3d3c84 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.keras.layers.BatchNormalization"
tf_class {
- is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalization\'>"
+ is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV1\'>"
+ is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV2\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-spec.pbtxt
index 5fd0a47..bc3ceb6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-spec.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-spec.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.keras.layers.InputSpec"
tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.InputSpec\'>"
+ is_instance: "<class \'tensorflow.python.keras.engine.input_spec.InputSpec\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.-mean-squared-error.pbtxt
new file mode 100644
index 0000000..a571853
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.-mean-squared-error.pbtxt
@@ -0,0 +1,22 @@
+path: "tensorflow.keras.losses.MeanSquaredError"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.losses.MeanSquaredError\'>"
+ is_instance: "<class \'tensorflow.python.keras.losses.Loss\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'sum_over_batch_size\', \'None\'], "
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.pbtxt
index eca6b91..a0af6a2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.losses.pbtxt
@@ -1,5 +1,9 @@
path: "tensorflow.keras.losses"
tf_module {
+ member {
+ name: "MeanSquaredError"
+ mtype: "<type \'type\'>"
+ }
member_method {
name: "KLD"
argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-accuracy.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-accuracy.pbtxt
index f53567a..2db07df 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.Accuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Accuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'accuracy\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-accuracy.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-accuracy.pbtxt
index f53567a..904ad3a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-binary-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.BinaryAccuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.BinaryAccuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\', \'threshold\'], varargs=None, keywords=None, defaults=[\'binary_accuracy\', \'None\', \'0.5\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-accuracy.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-accuracy.pbtxt
index f53567a..17b7492 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-categorical-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.CategoricalAccuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.CategoricalAccuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'categorical_accuracy\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-negatives.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-negatives.pbtxt
index f53567a..49f577e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-negatives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.FalseNegatives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.FalseNegatives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-positives.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-positives.pbtxt
index f53567a..e8baf85 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-false-positives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.FalsePositives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.FalsePositives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean.pbtxt
index f53567a..40fe64b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-mean.pbtxt
@@ -1,8 +1,7 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.Mean"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +14,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +62,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +83,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +103,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +111,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +138,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +174,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'values\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision.pbtxt
index f53567a..ae6a850 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-precision.pbtxt
@@ -1,8 +1,7 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.Precision"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Precision\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +14,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +62,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +83,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +103,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +111,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +138,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +174,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall.pbtxt
index f53567a..31068a5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-recall.pbtxt
@@ -1,8 +1,7 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.Recall"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Recall\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +14,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +62,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +83,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +103,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +111,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +138,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +174,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt
index f53567a..0c17452 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.SparseCategoricalAccuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.SparseCategoricalAccuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'sparse_categorical_accuracy\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-negatives.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-negatives.pbtxt
index f53567a..1b5eb8d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-negatives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.TrueNegatives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.TrueNegatives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-positives.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-positives.pbtxt
index f53567a..5b9c470 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.-true-positives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.TruePositives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.TruePositives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt
index a296e13..cc90d0e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.metrics.pbtxt
@@ -1,5 +1,49 @@
path: "tensorflow.keras.metrics"
tf_module {
+ member {
+ name: "Accuracy"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "BinaryAccuracy"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "CategoricalAccuracy"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "FalseNegatives"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "FalsePositives"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "Mean"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "Precision"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "Recall"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "SparseCategoricalAccuracy"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TrueNegatives"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TruePositives"
+ mtype: "<type \'type\'>"
+ }
member_method {
name: "KLD"
argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt
index b132bd4..16d9ecc 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-batch-normalization.pbtxt
@@ -1,7 +1,8 @@
path: "tensorflow.layers.BatchNormalization"
tf_class {
is_instance: "<class \'tensorflow.python.layers.normalization.BatchNormalization\'>"
- is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalization\'>"
+ is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV1\'>"
+ is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV2\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.-input-spec.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.-input-spec.pbtxt
index fd02c91..80834e0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.layers.-input-spec.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.-input-spec.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.layers.InputSpec"
tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.InputSpec\'>"
+ is_instance: "<class \'tensorflow.python.keras.engine.input_spec.InputSpec\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.losses.-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.losses.-mean-squared-error.pbtxt
new file mode 100644
index 0000000..a626d9c
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.losses.-mean-squared-error.pbtxt
@@ -0,0 +1,22 @@
+path: "tensorflow.losses.MeanSquaredError"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.losses.MeanSquaredError\'>"
+ is_instance: "<class \'tensorflow.python.keras.losses.Loss\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'sum_over_batch_size\', \'None\'], "
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.losses.-reduction.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.losses.-reduction.pbtxt
index b2adb52..258ad50 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.losses.-reduction.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.losses.-reduction.pbtxt
@@ -1,7 +1,6 @@
path: "tensorflow.losses.Reduction"
tf_class {
is_instance: "<class \'tensorflow.python.ops.losses.losses_impl.Reduction\'>"
- is_instance: "<class \'tensorflow.python.ops.losses.losses_impl.ReductionV2\'>"
is_instance: "<type \'object\'>"
member {
name: "MEAN"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.losses.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.losses.pbtxt
index c1d190a..a198db1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.losses.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.losses.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.losses"
tf_module {
member {
+ name: "MeanSquaredError"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "Reduction"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
index 67f348b..f34e2c2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
@@ -318,7 +318,7 @@
}
member_method {
name: "reduce_std"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_sum"
@@ -326,7 +326,7 @@
}
member_method {
name: "reduce_variance"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "rint"
@@ -342,7 +342,7 @@
}
member_method {
name: "scalar_mul"
- argspec: "args=[\'scalar\', \'x\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'scalar\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "segment_max"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-accuracy.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.metrics.-accuracy.pbtxt
index f53567a..f8e12f8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.Accuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Accuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'accuracy\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-binary-accuracy.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.metrics.-binary-accuracy.pbtxt
index f53567a..b9bc6a7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-binary-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.BinaryAccuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.BinaryAccuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\', \'threshold\'], varargs=None, keywords=None, defaults=[\'binary_accuracy\', \'None\', \'0.5\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-categorical-accuracy.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.metrics.-categorical-accuracy.pbtxt
index f53567a..0ef75d8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-categorical-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.CategoricalAccuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.CategoricalAccuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'categorical_accuracy\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-false-negatives.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.metrics.-false-negatives.pbtxt
index f53567a..33226a2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-false-negatives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.FalseNegatives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.FalseNegatives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-false-positives.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.metrics.-false-positives.pbtxt
index f53567a..9953162 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-false-positives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.FalsePositives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.FalsePositives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-mean.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.metrics.-mean.pbtxt
index f53567a..7fe6d6f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-mean.pbtxt
@@ -1,8 +1,7 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.Mean"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +14,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +62,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +83,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +103,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +111,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +138,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +174,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'values\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-precision.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.metrics.-precision.pbtxt
index f53567a..8c3271a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-precision.pbtxt
@@ -1,8 +1,7 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.Precision"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Precision\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +14,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +62,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +83,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +103,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +111,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +138,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +174,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-recall.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.metrics.-recall.pbtxt
index f53567a..840a68b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-recall.pbtxt
@@ -1,8 +1,7 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.Recall"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Recall\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +14,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +62,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +83,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +103,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +111,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +138,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +174,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt
index f53567a..7bce43f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.SparseCategoricalAccuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.SparseCategoricalAccuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'sparse_categorical_accuracy\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-true-negatives.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.metrics.-true-negatives.pbtxt
index f53567a..83cd5b7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-true-negatives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.TrueNegatives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.TrueNegatives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-true-positives.pbtxt
similarity index 74%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.metrics.-true-positives.pbtxt
index f53567a..5b2502e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.metrics.-true-positives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.TruePositives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.TruePositives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.metrics.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.metrics.pbtxt
index e9b996c..f5c267a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.metrics.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.metrics.pbtxt
@@ -1,5 +1,49 @@
path: "tensorflow.metrics"
tf_module {
+ member {
+ name: "Accuracy"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "BinaryAccuracy"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "CategoricalAccuracy"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "FalseNegatives"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "FalsePositives"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "Mean"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "Precision"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "Recall"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "SparseCategoricalAccuracy"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TrueNegatives"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TruePositives"
+ mtype: "<type \'type\'>"
+ }
member_method {
name: "accuracy"
argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt
index e781287..48501e1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt
@@ -45,6 +45,10 @@
argspec: "args=[\'cell_fw\', \'cell_bw\', \'inputs\', \'sequence_length\', \'initial_state_fw\', \'initial_state_bw\', \'dtype\', \'parallel_iterations\', \'swap_memory\', \'time_major\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'None\'], "
}
member_method {
+ name: "collapse_repeated"
+ argspec: "args=[\'labels\', \'seq_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "compute_accidental_hits"
argspec: "args=[\'true_classes\', \'sampled_candidates\', \'num_true\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
@@ -109,6 +113,14 @@
argspec: "args=[\'labels\', \'inputs\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'True\'], "
}
member_method {
+ name: "ctc_loss_v2"
+ argspec: "args=[\'labels\', \'logits\', \'label_length\', \'logit_length\', \'logits_time_major\', \'unique\', \'blank_index\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "ctc_unique_labels"
+ argspec: "args=[\'labels\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "depth_to_space"
argspec: "args=[\'input\', \'block_size\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'NHWC\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 6a45bc7..a294e3e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -701,6 +701,10 @@
argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\', \'output_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
}
member_method {
+ name: "argsort"
+ argspec: "args=[\'values\', \'axis\', \'direction\', \'stable\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'ASCENDING\', \'False\', \'None\'], "
+ }
+ member_method {
name: "as_dtype"
argspec: "args=[\'type_value\'], varargs=None, keywords=None, defaults=None"
}
@@ -1241,6 +1245,10 @@
argspec: "args=[\'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'collections\', \'caching_device\', \'partitioner\', \'validate_shape\', \'use_resource\', \'custom_getter\', \'constraint\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'None\', \'True\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
+ name: "get_logger"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "get_seed"
argspec: "args=[\'op_seed\'], varargs=None, keywords=None, defaults=None"
}
@@ -1474,7 +1482,7 @@
}
member_method {
name: "make_tensor_proto"
- argspec: "args=[\'values\', \'dtype\', \'shape\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
+ argspec: "args=[\'values\', \'dtype\', \'shape\', \'verify_shape\', \'allow_broadcast\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\'], "
}
member_method {
name: "map_fn"
@@ -1818,7 +1826,7 @@
}
member_method {
name: "scalar_mul"
- argspec: "args=[\'scalar\', \'x\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'scalar\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "scan"
@@ -1957,6 +1965,10 @@
argspec: "args=[\'input_\', \'begin\', \'size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "sort"
+ argspec: "args=[\'values\', \'axis\', \'direction\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'ASCENDING\', \'None\'], "
+ }
+ member_method {
name: "space_to_batch"
argspec: "args=[\'input\', \'paddings\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt
index 2948b73..632c2f8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt
@@ -34,7 +34,7 @@
}
member_method {
name: "quantize_and_dequantize"
- argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'None\'], "
+ argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'None\'], "
}
member_method {
name: "quantized_concat"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.test.-benchmark.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.test.-benchmark.pbtxt
index df528e2..6fc489c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.test.-benchmark.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.test.-benchmark.pbtxt
@@ -7,6 +7,10 @@
name: "__init__"
}
member_method {
+ name: "evaluate"
+ argspec: "args=[\'self\', \'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "is_abstract"
argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt
index 877c55c..bdb3ea2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt
@@ -397,6 +397,10 @@
argspec: "args=[\'x\', \'boundaries\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "piecewise_constant_decay"
+ argspec: "args=[\'x\', \'boundaries\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "polynomial_decay"
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'end_learning_rate\', \'power\', \'cycle\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1.0\', \'False\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
index 0a16d6a..e37d299 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-gradient-tape.pbtxt
@@ -7,10 +7,18 @@
argspec: "args=[\'self\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], "
}
member_method {
+ name: "batch_jacobian"
+ argspec: "args=[\'self\', \'target\', \'source\', \'unconnected_gradients\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'True\'], "
+ }
+ member_method {
name: "gradient"
argspec: "args=[\'self\', \'target\', \'sources\', \'output_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'UnconnectedGradients.NONE\'], "
}
member_method {
+ name: "jacobian"
+ argspec: "args=[\'self\', \'target\', \'sources\', \'unconnected_gradients\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'True\'], "
+ }
+ member_method {
name: "reset"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-iterator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-iterator.pbtxt
deleted file mode 100644
index 4f0147a..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-iterator.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.data.Iterator"
-tf_class {
- is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "initializer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_classes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shapes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_types"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'iterator_resource\', \'initializer\', \'output_types\', \'output_shapes\', \'output_classes\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "from_string_handle"
- argspec: "args=[\'string_handle\', \'output_types\', \'output_shapes\', \'output_classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "from_structure"
- argspec: "args=[\'output_types\', \'output_shapes\', \'shared_name\', \'output_classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "get_next"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "make_initializer"
- argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "string_handle"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt
index 9d032d4..024b451 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.data.Options"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Options\'>"
+ is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_autotune"
@@ -54,6 +55,10 @@
name: "experimental_stats"
mtype: "<type \'property\'>"
}
+ member {
+ name: "experimental_threading"
+ mtype: "<type \'property\'>"
+ }
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-stats-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-stats-options.pbtxt
index f423eed..892f8c1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-stats-options.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-stats-options.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.data.experimental.StatsOptions"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_options.StatsOptions\'>"
+ is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
is_instance: "<type \'object\'>"
member {
name: "aggregator"
@@ -20,6 +21,6 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'aggregator\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-threading-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-threading-options.pbtxt
new file mode 100644
index 0000000..5b5ebf1
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-threading-options.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.data.experimental.ThreadingOptions"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.threading_options.ThreadingOptions\'>"
+ is_instance: "<class \'tensorflow.python.data.util.options.OptionsBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "max_intra_op_parallelism"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "private_threadpool_size"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt
index 244b245..7bc3faa 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt
@@ -40,6 +40,10 @@
name: "TFRecordWriter"
mtype: "<type \'type\'>"
}
+ member {
+ name: "ThreadingOptions"
+ mtype: "<type \'type\'>"
+ }
member_method {
name: "Counter"
argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt
index 509bbae..4c3d6dd 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt
@@ -9,10 +9,6 @@
mtype: "<type \'type\'>"
}
member {
- name: "Iterator"
- mtype: "<type \'type\'>"
- }
- member {
name: "Options"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-context.pbtxt
index c39ac5a..583cbc6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-context.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-context.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.distribute.InputContext"
tf_class {
- is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>"
+ is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputContext\'>"
is_instance: "<type \'object\'>"
member {
name: "input_pipeline_id"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt
index 3eda6c6..df707e8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.distribute.ReplicaContext"
tf_class {
- is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>"
+ is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ReplicaContext\'>"
is_instance: "<type \'object\'>"
member {
name: "devices"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt
index 3b502b5..77706e5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.distribute.StrategyExtended"
tf_class {
- is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>"
+ is_instance: "<class \'tensorflow.python.distribute.distribute_lib.DistributionStrategyExtended\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_between_graph"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt
index 4fe035b..0fd9a3b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.distribute.Strategy"
tf_class {
- is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>"
+ is_instance: "<class \'tensorflow.python.distribute.distribute_lib.DistributionStrategy\'>"
is_instance: "<type \'object\'>"
member {
name: "between_graph"
@@ -123,6 +123,10 @@
argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
+ name: "update_config_proto"
+ argspec: "args=[\'self\', \'config_proto\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "update_non_slot"
argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-classifier.pbtxt
index 07483df..22cbcf0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-classifier.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-estimator.pbtxt
index d218773..38b27f7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-estimator.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-regressor.pbtxt
index 292b5f3..a965042 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-baseline-regressor.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 970abd8..e138ce9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -34,6 +34,10 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "experimental_feature_importances"
argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
@@ -43,7 +47,7 @@
}
member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index b5bbad9..eae0a29 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -34,6 +34,10 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "experimental_feature_importances"
argspec: "args=[\'self\', \'normalize\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
@@ -43,7 +47,7 @@
}
member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-classifier.pbtxt
index c542edf..f6bd4d2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-classifier.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-estimator.pbtxt
index 85ff5a4..09e0d38 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-estimator.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
index 623cbc3..60627cc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-estimator.pbtxt
index ac13dad..e311f96 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-estimator.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
index f45e765..dc6aca2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-regressor.pbtxt
index 8db2196..7338abc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-d-n-n-regressor.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator.pbtxt
index 71531fd..a1f0e76 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-estimator.pbtxt
@@ -31,8 +31,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "get_variable_names"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-classifier.pbtxt
index 72c226b..6559c58 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-classifier.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-estimator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-estimator.pbtxt
index 023edec..2148374 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-estimator.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-regressor.pbtxt
index c4bb196..e6ea074 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-linear-regressor.pbtxt
@@ -33,8 +33,12 @@
argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "experimental_export_all_saved_models"
+ argspec: "args=[\'self\', \'export_dir_base\', \'input_receiver_fn_map\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
name: "export_saved_model"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'experimental_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'infer\'], "
}
member_method {
name: "export_savedmodel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.experimental.-in-memory-evaluator-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.experimental.-in-memory-evaluator-hook.pbtxt
new file mode 100644
index 0000000..aba1202
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.experimental.-in-memory-evaluator-hook.pbtxt
@@ -0,0 +1,30 @@
+path: "tensorflow.estimator.experimental.InMemoryEvaluatorHook"
+tf_class {
+ is_instance: "<class \'tensorflow_estimator.python.estimator.hooks.InMemoryEvaluatorHook\'>"
+ is_instance: "<class \'tensorflow.python.training.session_run_hook.SessionRunHook\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'estimator\', \'input_fn\', \'steps\', \'hooks\', \'name\', \'every_n_iter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'100\'], "
+ }
+ member_method {
+ name: "after_create_session"
+ argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "after_run"
+ argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "before_run"
+ argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "begin"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "end"
+ argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.experimental.pbtxt
index cabca3e..2a9a034 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.experimental.pbtxt
@@ -1,10 +1,18 @@
path: "tensorflow.estimator.experimental"
tf_module {
member {
+ name: "InMemoryEvaluatorHook"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "LinearSDCA"
mtype: "<type \'type\'>"
}
member_method {
+ name: "build_raw_supervised_input_receiver_fn"
+ argspec: "args=[\'features\', \'labels\', \'default_batch_size\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "call_logit_fn"
argspec: "args=[\'logit_fn\', \'features\', \'mode\', \'params\', \'config\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt
index f6e165b..3aadd7d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.feature_column.pbtxt
@@ -14,7 +14,7 @@
}
member_method {
name: "categorical_column_with_vocabulary_file"
- argspec: "args=[\'key\', \'vocabulary_file\', \'vocabulary_size\', \'num_oov_buckets\', \'default_value\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\', \"<dtype: \'string\'>\"], "
+ argspec: "args=[\'key\', \'vocabulary_file\', \'vocabulary_size\', \'dtype\', \'default_value\', \'num_oov_buckets\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'string\'>\", \'None\', \'0\'], "
}
member_method {
name: "categorical_column_with_vocabulary_list"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.gfile.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.gfile.pbtxt
deleted file mode 100644
index 74d0a05..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.gfile.pbtxt
+++ /dev/null
@@ -1,47 +0,0 @@
-path: "tensorflow.gfile"
-tf_module {
- member_method {
- name: "Copy"
- argspec: "args=[\'oldpath\', \'newpath\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "DeleteRecursively"
- argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "Glob"
- argspec: "args=[\'filename\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "IsDirectory"
- argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "ListDirectory"
- argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "MakeDirs"
- argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "MkDir"
- argspec: "args=[\'dirname\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "Remove"
- argspec: "args=[\'filename\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "Rename"
- argspec: "args=[\'oldname\', \'newname\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "Stat"
- argspec: "args=[\'filename\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "Walk"
- argspec: "args=[\'top\', \'in_order\'], varargs=None, keywords=None, defaults=[\'True\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
index f25fb65..3c6ed1c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
@@ -38,7 +38,7 @@
}
member_method {
name: "crop_and_resize"
- argspec: "args=[\'image\', \'boxes\', \'box_ind\', \'crop_size\', \'method\', \'extrapolation_value\', \'name\'], varargs=None, keywords=None, defaults=[\'bilinear\', \'0\', \'None\'], "
+ argspec: "args=[\'image\', \'boxes\', \'box_indices\', \'crop_size\', \'method\', \'extrapolation_value\', \'name\'], varargs=None, keywords=None, defaults=[\'bilinear\', \'0\', \'None\'], "
}
member_method {
name: "crop_to_bounding_box"
@@ -86,7 +86,7 @@
}
member_method {
name: "extract_image_patches"
- argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'images\', \'sizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "extract_jpeg_shape"
@@ -206,7 +206,7 @@
}
member_method {
name: "sample_distorted_bounding_box"
- argspec: "args=[\'image_size\', \'bounding_boxes\', \'seed\', \'seed2\', \'min_object_covered\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.1\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'image_size\', \'bounding_boxes\', \'seed\', \'min_object_covered\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0.1\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "sobel_edges"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.gfile.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.gfile.pbtxt
index 59652cb..e5aba7e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.io.gfile.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.gfile.pbtxt
@@ -1,7 +1,51 @@
path: "tensorflow.io.gfile"
tf_module {
member_method {
+ name: "copy"
+ argspec: "args=[\'src\', \'dst\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
name: "exists"
argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
}
+ member_method {
+ name: "glob"
+ argspec: "args=[\'pattern\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "isdir"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "listdir"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "makedirs"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "mkdir"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "remove"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "rename"
+ argspec: "args=[\'src\', \'dst\', \'overwrite\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "rmtree"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "stat"
+ argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "walk"
+ argspec: "args=[\'top\', \'topdown\', \'onerror\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt
index b27df17..8906329 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt
@@ -49,22 +49,46 @@
mtype: "<type \'module\'>"
}
member_method {
+ name: "decode_and_crop_jpeg"
+ argspec: "args=[\'contents\', \'crop_window\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], "
+ }
+ member_method {
name: "decode_base64"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "decode_bmp"
+ argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
+ }
+ member_method {
name: "decode_compressed"
argspec: "args=[\'bytes\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
name: "decode_csv"
- argspec: "args=[\'records\', \'record_defaults\', \'field_delim\', \'use_quote_delim\', \'name\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\',\', \'True\', \'None\', \'\', \'None\'], "
+ argspec: "args=[\'records\', \'record_defaults\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\', \'name\'], varargs=None, keywords=None, defaults=[\',\', \'True\', \'\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "decode_gif"
+ argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "decode_image"
+ argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'uint8\'>\", \'None\'], "
+ }
+ member_method {
+ name: "decode_jpeg"
+ argspec: "args=[\'contents\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], "
}
member_method {
name: "decode_json_example"
argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "decode_png"
+ argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'uint8\'>\", \'None\'], "
+ }
+ member_method {
name: "decode_raw"
argspec: "args=[\'bytes\', \'out_type\', \'little_endian\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
@@ -77,6 +101,18 @@
argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
+ name: "encode_jpeg"
+ argspec: "args=[\'image\', \'format\', \'quality\', \'progressive\', \'optimize_size\', \'chroma_downsampling\', \'density_unit\', \'x_density\', \'y_density\', \'xmp_metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'95\', \'False\', \'False\', \'True\', \'in\', \'300\', \'300\', \'\', \'None\'], "
+ }
+ member_method {
+ name: "extract_jpeg_shape"
+ argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
+ }
+ member_method {
+ name: "is_jpeg"
+ argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "match_filenames_once"
argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -86,7 +122,7 @@
}
member_method {
name: "parse_example"
- argspec: "args=[\'serialized\', \'features\', \'name\', \'example_names\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'serialized\', \'features\', \'example_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "parse_sequence_example"
@@ -94,7 +130,7 @@
}
member_method {
name: "parse_single_example"
- argspec: "args=[\'serialized\', \'features\', \'name\', \'example_names\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'serialized\', \'features\', \'example_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "parse_single_sequence_example"
@@ -110,21 +146,17 @@
}
member_method {
name: "serialize_many_sparse"
- argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'string\'>\"], "
+ argspec: "args=[\'sp_input\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'string\'>\", \'None\'], "
}
member_method {
name: "serialize_sparse"
- argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'string\'>\"], "
+ argspec: "args=[\'sp_input\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'string\'>\", \'None\'], "
}
member_method {
name: "serialize_tensor"
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "tf_record_iterator"
- argspec: "args=[\'path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
name: "write_file"
argspec: "args=[\'filename\', \'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt
index 8200345..5da7926 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.keras.layers.BatchNormalization"
tf_class {
- is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalization\'>"
+ is_instance: "<class \'tensorflow.python.keras.layers.normalization.BatchNormalizationV2\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt
index 5fd0a47..bc3ceb6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.keras.layers.InputSpec"
tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.InputSpec\'>"
+ is_instance: "<class \'tensorflow.python.keras.engine.input_spec.InputSpec\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-mean-squared-error.pbtxt
new file mode 100644
index 0000000..a571853
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-mean-squared-error.pbtxt
@@ -0,0 +1,22 @@
+path: "tensorflow.keras.losses.MeanSquaredError"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.losses.MeanSquaredError\'>"
+ is_instance: "<class \'tensorflow.python.keras.losses.Loss\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'sum_over_batch_size\', \'None\'], "
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-reduction.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-reduction.pbtxt
new file mode 100644
index 0000000..031d9b1
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.-reduction.pbtxt
@@ -0,0 +1,28 @@
+path: "tensorflow.keras.losses.Reduction"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.utils.losses_utils.ReductionV2\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "NONE"
+ mtype: "<type \'NoneType\'>"
+ }
+ member {
+ name: "SUM"
+ mtype: "<type \'str\'>"
+ }
+ member {
+ name: "SUM_OVER_BATCH_SIZE"
+ mtype: "<type \'str\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "all"
+ argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "validate"
+ argspec: "args=[\'cls\', \'key\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt
index eca6b91..cb156e2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.losses.pbtxt
@@ -1,5 +1,13 @@
path: "tensorflow.keras.losses"
tf_module {
+ member {
+ name: "MeanSquaredError"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "Reduction"
+ mtype: "<type \'type\'>"
+ }
member_method {
name: "KLD"
argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-accuracy.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-accuracy.pbtxt
index f53567a..2db07df 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.Accuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Accuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'accuracy\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-accuracy.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-accuracy.pbtxt
index f53567a..904ad3a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-binary-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.BinaryAccuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.BinaryAccuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\', \'threshold\'], varargs=None, keywords=None, defaults=[\'binary_accuracy\', \'None\', \'0.5\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-accuracy.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-accuracy.pbtxt
index f53567a..17b7492 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-categorical-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.CategoricalAccuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.CategoricalAccuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'categorical_accuracy\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-negatives.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-negatives.pbtxt
index f53567a..49f577e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-negatives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.FalseNegatives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.FalseNegatives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-positives.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-positives.pbtxt
index f53567a..e8baf85 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-false-positives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.FalsePositives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.FalsePositives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean.pbtxt
index f53567a..40fe64b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-mean.pbtxt
@@ -1,8 +1,7 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.Mean"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +14,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +62,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +83,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +103,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +111,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +138,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +174,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'values\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision.pbtxt
index f53567a..ae6a850 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-precision.pbtxt
@@ -1,8 +1,7 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.Precision"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Precision\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +14,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +62,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +83,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +103,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +111,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +138,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +174,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall.pbtxt
index f53567a..31068a5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-recall.pbtxt
@@ -1,8 +1,7 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.Recall"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Recall\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +14,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +62,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +83,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +103,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +111,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +138,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +174,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt
index f53567a..0c17452 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-sparse-categorical-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.SparseCategoricalAccuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.SparseCategoricalAccuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'sparse_categorical_accuracy\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-negatives.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-negatives.pbtxt
index f53567a..1b5eb8d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-negatives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.TrueNegatives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.TrueNegatives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-positives.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-positives.pbtxt
index f53567a..5b9c470 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.-true-positives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.keras.metrics.TruePositives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.TruePositives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt
index a296e13..cc90d0e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.metrics.pbtxt
@@ -1,5 +1,49 @@
path: "tensorflow.keras.metrics"
tf_module {
+ member {
+ name: "Accuracy"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "BinaryAccuracy"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "CategoricalAccuracy"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "FalseNegatives"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "FalsePositives"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "Mean"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "Precision"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "Recall"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "SparseCategoricalAccuracy"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TrueNegatives"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TruePositives"
+ mtype: "<type \'type\'>"
+ }
member_method {
name: "KLD"
argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
index 1a4098d..a3599bf 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
@@ -118,7 +118,7 @@
}
member_method {
name: "l2_normalize"
- argspec: "args=[\'x\', \'axis\', \'epsilon\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'1e-12\', \'None\', \'None\'], "
+ argspec: "args=[\'x\', \'axis\', \'epsilon\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1e-12\', \'None\'], "
}
member_method {
name: "logdet"
@@ -142,7 +142,7 @@
}
member_method {
name: "norm"
- argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "qr"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.losses.-mean-squared-error.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.losses.-mean-squared-error.pbtxt
new file mode 100644
index 0000000..a626d9c
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.losses.-mean-squared-error.pbtxt
@@ -0,0 +1,22 @@
+path: "tensorflow.losses.MeanSquaredError"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.losses.MeanSquaredError\'>"
+ is_instance: "<class \'tensorflow.python.keras.losses.Loss\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'sum_over_batch_size\', \'None\'], "
+ }
+ member_method {
+ name: "call"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.losses.-reduction.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.losses.-reduction.pbtxt
index 6a44e4c..ad72e31 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.losses.-reduction.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.losses.-reduction.pbtxt
@@ -1,10 +1,10 @@
path: "tensorflow.losses.Reduction"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.losses.losses_impl.ReductionV2\'>"
+ is_instance: "<class \'tensorflow.python.keras.utils.losses_utils.ReductionV2\'>"
is_instance: "<type \'object\'>"
member {
name: "NONE"
- mtype: "<type \'str\'>"
+ mtype: "<type \'NoneType\'>"
}
member {
name: "SUM"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.losses.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.losses.pbtxt
index c1d190a..87f5ef3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.losses.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.losses.pbtxt
@@ -1,26 +1,18 @@
path: "tensorflow.losses"
tf_module {
member {
+ name: "MeanSquaredError"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "Reduction"
mtype: "<type \'type\'>"
}
member_method {
- name: "absolute_difference"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], "
- }
- member_method {
name: "add_loss"
argspec: "args=[\'loss\', \'loss_collection\'], varargs=None, keywords=None, defaults=[\'losses\'], "
}
member_method {
- name: "compute_weighted_loss"
- argspec: "args=[\'losses\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], "
- }
- member_method {
- name: "cosine_distance"
- argspec: "args=[\'labels\', \'predictions\', \'axis\', \'weights\', \'scope\', \'loss_collection\', \'reduction\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\', \'None\'], "
- }
- member_method {
name: "get_losses"
argspec: "args=[\'scope\', \'loss_collection\'], varargs=None, keywords=None, defaults=[\'None\', \'losses\'], "
}
@@ -36,36 +28,4 @@
name: "get_total_loss"
argspec: "args=[\'add_regularization_losses\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'total_loss\'], "
}
- member_method {
- name: "hinge_loss"
- argspec: "args=[\'labels\', \'logits\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], "
- }
- member_method {
- name: "huber_loss"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'delta\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], "
- }
- member_method {
- name: "log_loss"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'epsilon\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'1e-07\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], "
- }
- member_method {
- name: "mean_pairwise_squared_error"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'scope\', \'loss_collection\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\'], "
- }
- member_method {
- name: "mean_squared_error"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], "
- }
- member_method {
- name: "sigmoid_cross_entropy"
- argspec: "args=[\'multi_class_labels\', \'logits\', \'weights\', \'label_smoothing\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], "
- }
- member_method {
- name: "softmax_cross_entropy"
- argspec: "args=[\'onehot_labels\', \'logits\', \'weights\', \'label_smoothing\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], "
- }
- member_method {
- name: "sparse_softmax_cross_entropy"
- argspec: "args=[\'labels\', \'logits\', \'weights\', \'scope\', \'loss_collection\', \'reduction\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\', \'losses\', \'weighted_sum_by_nonzero_weights\'], "
- }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
index e6b8fd2..979d77e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
@@ -78,7 +78,7 @@
}
member_method {
name: "bincount"
- argspec: "args=[\'arr\', \'weights\', \'minlength\', \'maxlength\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int32\'>\"], "
+ argspec: "args=[\'arr\', \'weights\', \'minlength\', \'maxlength\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int32\'>\", \'None\'], "
}
member_method {
name: "ceil"
@@ -86,7 +86,7 @@
}
member_method {
name: "confusion_matrix"
- argspec: "args=[\'labels\', \'predictions\', \'num_classes\', \'dtype\', \'name\', \'weights\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'int32\'>\", \'None\', \'None\'], "
+ argspec: "args=[\'labels\', \'predictions\', \'num_classes\', \'weights\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"<dtype: \'int32\'>\", \'None\'], "
}
member_method {
name: "conj"
@@ -198,7 +198,7 @@
}
member_method {
name: "l2_normalize"
- argspec: "args=[\'x\', \'axis\', \'epsilon\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'1e-12\', \'None\', \'None\'], "
+ argspec: "args=[\'x\', \'axis\', \'epsilon\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1e-12\', \'None\'], "
}
member_method {
name: "lbeta"
@@ -230,7 +230,7 @@
}
member_method {
name: "log_softmax"
- argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'logits\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "logical_and"
@@ -290,43 +290,43 @@
}
member_method {
name: "reduce_all"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_any"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_logsumexp"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_max"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_mean"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_min"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_prod"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_std"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_sum"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_variance"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "rint"
@@ -342,7 +342,7 @@
}
member_method {
name: "scalar_mul"
- argspec: "args=[\'scalar\', \'x\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'scalar\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "segment_max"
@@ -382,7 +382,7 @@
}
member_method {
name: "softmax"
- argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'logits\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "softplus"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-accuracy.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.metrics.-accuracy.pbtxt
index f53567a..f8e12f8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.Accuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Accuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'accuracy\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-accuracy.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-accuracy.pbtxt
index f53567a..b9bc6a7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-binary-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.BinaryAccuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.BinaryAccuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\', \'threshold\'], varargs=None, keywords=None, defaults=[\'binary_accuracy\', \'None\', \'0.5\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-accuracy.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-accuracy.pbtxt
index f53567a..0ef75d8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-categorical-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.CategoricalAccuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.CategoricalAccuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'categorical_accuracy\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-negatives.pbtxt
similarity index 75%
rename from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
rename to tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-negatives.pbtxt
index f53567a..33226a2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-negatives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.FalseNegatives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.FalseNegatives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-positives.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-positives.pbtxt
index f53567a..9953162 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-false-positives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.FalsePositives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.FalsePositives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean.pbtxt
index f53567a..7fe6d6f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-mean.pbtxt
@@ -1,8 +1,7 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.Mean"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +14,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +62,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +83,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'mean\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +103,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +111,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +138,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +174,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'values\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision.pbtxt
index f53567a..8c3271a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-precision.pbtxt
@@ -1,8 +1,7 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.Precision"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Precision\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +14,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +62,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +83,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +103,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +111,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +138,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +174,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall.pbtxt
index f53567a..840a68b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-recall.pbtxt
@@ -1,8 +1,7 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.Recall"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Recall\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +14,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +62,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +83,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +103,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +111,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +138,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +174,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt
index f53567a..7bce43f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-sparse-categorical-accuracy.pbtxt
@@ -1,8 +1,9 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.SparseCategoricalAccuracy"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.SparseCategoricalAccuracy\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.MeanMetricWrapper\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Mean\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +16,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +64,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +85,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'sparse_categorical_accuracy\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +113,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +140,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +176,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-negatives.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-negatives.pbtxt
index f53567a..83cd5b7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-negatives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.TrueNegatives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.TrueNegatives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-positives.pbtxt
similarity index 75%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
copy to tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-positives.pbtxt
index f53567a..5b2502e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.-true-positives.pbtxt
@@ -1,8 +1,8 @@
-path: "tensorflow.nn.rnn_cell.MultiRNNCell"
+path: "tensorflow.metrics.TruePositives"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.MultiRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.TruePositives\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics._ConfusionMatrixConditionCount\'>"
+ is_instance: "<class \'tensorflow.python.keras.metrics.Metric\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
is_instance: "<type \'object\'>"
@@ -15,10 +15,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -67,18 +63,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -100,7 +84,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cells\', \'state_is_tuple\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'thresholds\', \'name\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -120,7 +104,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'aggregation\', \'synchronization\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'VariableAggregation.SUM\', \'VariableSynchronization.ON_READ\', \'None\'], "
}
member_method {
name: "apply"
@@ -128,11 +112,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
@@ -155,10 +139,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -195,11 +175,19 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "reset_states"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "result"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
+ name: "update_state"
+ argspec: "args=[\'self\', \'y_true\', \'y_pred\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt
index e9b996c..773efd0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.metrics.pbtxt
@@ -1,135 +1,47 @@
path: "tensorflow.metrics"
tf_module {
- member_method {
- name: "accuracy"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ member {
+ name: "Accuracy"
+ mtype: "<type \'type\'>"
}
- member_method {
- name: "auc"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'curve\', \'name\', \'summation_method\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'ROC\', \'None\', \'trapezoidal\'], "
+ member {
+ name: "BinaryAccuracy"
+ mtype: "<type \'type\'>"
}
- member_method {
- name: "average_precision_at_k"
- argspec: "args=[\'labels\', \'predictions\', \'k\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ member {
+ name: "CategoricalAccuracy"
+ mtype: "<type \'type\'>"
}
- member_method {
- name: "false_negatives"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ member {
+ name: "FalseNegatives"
+ mtype: "<type \'type\'>"
}
- member_method {
- name: "false_negatives_at_thresholds"
- argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ member {
+ name: "FalsePositives"
+ mtype: "<type \'type\'>"
}
- member_method {
- name: "false_positives"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ member {
+ name: "Mean"
+ mtype: "<type \'type\'>"
}
- member_method {
- name: "false_positives_at_thresholds"
- argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ member {
+ name: "Precision"
+ mtype: "<type \'type\'>"
}
- member_method {
- name: "mean"
- argspec: "args=[\'values\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ member {
+ name: "Recall"
+ mtype: "<type \'type\'>"
}
- member_method {
- name: "mean_absolute_error"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ member {
+ name: "SparseCategoricalAccuracy"
+ mtype: "<type \'type\'>"
}
- member_method {
- name: "mean_cosine_distance"
- argspec: "args=[\'labels\', \'predictions\', \'dim\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ member {
+ name: "TrueNegatives"
+ mtype: "<type \'type\'>"
}
- member_method {
- name: "mean_iou"
- argspec: "args=[\'labels\', \'predictions\', \'num_classes\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "mean_per_class_accuracy"
- argspec: "args=[\'labels\', \'predictions\', \'num_classes\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "mean_relative_error"
- argspec: "args=[\'labels\', \'predictions\', \'normalizer\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "mean_squared_error"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "mean_tensor"
- argspec: "args=[\'values\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "percentage_below"
- argspec: "args=[\'values\', \'threshold\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "precision"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "precision_at_k"
- argspec: "args=[\'labels\', \'predictions\', \'k\', \'class_id\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "precision_at_thresholds"
- argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "precision_at_top_k"
- argspec: "args=[\'labels\', \'predictions_idx\', \'k\', \'class_id\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "recall"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "recall_at_k"
- argspec: "args=[\'labels\', \'predictions\', \'k\', \'class_id\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "recall_at_thresholds"
- argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "recall_at_top_k"
- argspec: "args=[\'labels\', \'predictions_idx\', \'k\', \'class_id\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "root_mean_squared_error"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "sensitivity_at_specificity"
- argspec: "args=[\'labels\', \'predictions\', \'specificity\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "sparse_average_precision_at_k"
- argspec: "args=[\'labels\', \'predictions\', \'k\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "sparse_precision_at_k"
- argspec: "args=[\'labels\', \'predictions\', \'k\', \'class_id\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "specificity_at_sensitivity"
- argspec: "args=[\'labels\', \'predictions\', \'sensitivity\', \'weights\', \'num_thresholds\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'200\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "true_negatives"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "true_negatives_at_thresholds"
- argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "true_positives"
- argspec: "args=[\'labels\', \'predictions\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "true_positives_at_thresholds"
- argspec: "args=[\'labels\', \'predictions\', \'thresholds\', \'weights\', \'metrics_collections\', \'updates_collections\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ member {
+ name: "TruePositives"
+ mtype: "<type \'type\'>"
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt
index 34ca207..04b1897 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt
@@ -41,8 +41,8 @@
argspec: "args=[\'value\', \'bias\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
- name: "bidirectional_dynamic_rnn"
- argspec: "args=[\'cell_fw\', \'cell_bw\', \'inputs\', \'sequence_length\', \'initial_state_fw\', \'initial_state_bw\', \'dtype\', \'parallel_iterations\', \'swap_memory\', \'time_major\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'False\', \'None\'], "
+ name: "collapse_repeated"
+ argspec: "args=[\'labels\', \'seq_length\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_accidental_hits"
@@ -98,7 +98,11 @@
}
member_method {
name: "ctc_loss"
- argspec: "args=[\'labels\', \'inputs\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'True\'], "
+ argspec: "args=[\'labels\', \'logits\', \'label_length\', \'logit_length\', \'logits_time_major\', \'unique\', \'blank_index\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "ctc_unique_labels"
+ argspec: "args=[\'labels\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "depth_to_space"
@@ -118,11 +122,11 @@
}
member_method {
name: "dilation2d"
- argspec: "args=[\'input\', \'filter\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'filters\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "dropout"
- argspec: "args=[\'x\', \'keep_prob\', \'noise_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'x\', \'rate\', \'noise_shape\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "elu"
@@ -166,7 +170,7 @@
}
member_method {
name: "l2_normalize"
- argspec: "args=[\'x\', \'axis\', \'epsilon\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'1e-12\', \'None\', \'None\'], "
+ argspec: "args=[\'x\', \'axis\', \'epsilon\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'1e-12\', \'None\'], "
}
member_method {
name: "leaky_relu"
@@ -186,7 +190,7 @@
}
member_method {
name: "log_softmax"
- argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'logits\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "lrn"
@@ -202,7 +206,7 @@
}
member_method {
name: "max_pool_with_argmax"
- argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'Targmax\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
+ argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'output_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
name: "moments"
@@ -258,7 +262,7 @@
}
member_method {
name: "softmax"
- argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'logits\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "softmax_cross_entropy_with_logits"
@@ -285,10 +289,6 @@
argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
- name: "static_bidirectional_rnn"
- argspec: "args=[\'cell_fw\', \'cell_bw\', \'inputs\', \'initial_state_fw\', \'initial_state_bw\', \'dtype\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
name: "static_state_saving_rnn"
argspec: "args=[\'cell\', \'inputs\', \'state_saver\', \'state_name\', \'sequence_length\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt
index 3c78b07..b1f687f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt
@@ -13,10 +13,6 @@
mtype: "<type \'type\'>"
}
member {
- name: "MultiRNNCell"
- mtype: "<type \'type\'>"
- }
- member {
name: "RNNCell"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index b03c8c2..0659900 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -205,10 +205,6 @@
mtype: "<type \'module\'>"
}
member {
- name: "flags"
- mtype: "<type \'module\'>"
- }
- member {
name: "float16"
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
@@ -221,10 +217,6 @@
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
- name: "gfile"
- mtype: "<type \'module\'>"
- }
- member {
name: "glorot_uniform_initializer"
mtype: "<type \'type\'>"
}
@@ -305,10 +297,6 @@
mtype: "<type \'type\'>"
}
member {
- name: "pywrap_tensorflow"
- mtype: "<type \'module\'>"
- }
- member {
name: "qint16"
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
@@ -453,6 +441,10 @@
argspec: "args=[\'input\', \'axis\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'int64\'>\", \'None\'], "
}
member_method {
+ name: "argsort"
+ argspec: "args=[\'values\', \'axis\', \'direction\', \'stable\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'ASCENDING\', \'False\', \'None\'], "
+ }
+ member_method {
name: "as_dtype"
argspec: "args=[\'type_value\'], varargs=None, keywords=None, defaults=None"
}
@@ -506,10 +498,6 @@
}
member_method {
name: "batch_to_space"
- argspec: "args=[\'input\', \'crops\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "batch_to_space_nd"
argspec: "args=[\'input\', \'block_shape\', \'crops\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
@@ -570,7 +558,7 @@
}
member_method {
name: "constant"
- argspec: "args=[\'value\', \'dtype\', \'shape\', \'name\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Const\', \'False\'], "
+ argspec: "args=[\'value\', \'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'Const\'], "
}
member_method {
name: "control_dependencies"
@@ -589,10 +577,6 @@
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "create_partitioned_variables"
- argspec: "args=[\'shape\', \'slicing\', \'initializer\', \'dtype\', \'trainable\', \'collections\', \'name\', \'reuse\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'True\', \'None\', \'None\', \'None\'], "
- }
- member_method {
name: "cumsum"
argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
}
@@ -665,10 +649,6 @@
argspec: "args=[\'dims\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "fixed_size_partitioner"
- argspec: "args=[\'num_shards\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], "
- }
- member_method {
name: "floor"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -694,15 +674,19 @@
}
member_method {
name: "gather"
- argspec: "args=[\'params\', \'indices\', \'validate_indices\', \'name\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0\'], "
+ argspec: "args=[\'params\', \'indices\', \'validate_indices\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
}
member_method {
name: "gather_nd"
argspec: "args=[\'params\', \'indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "get_logger"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "gradients"
- argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
+ argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
}
member_method {
name: "greater"
@@ -722,7 +706,7 @@
}
member_method {
name: "hessians"
- argspec: "args=[\'ys\', \'xs\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\'], varargs=None, keywords=None, defaults=[\'hessians\', \'False\', \'False\', \'None\'], "
+ argspec: "args=[\'ys\', \'xs\', \'gate_gradients\', \'aggregation_method\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'hessians\'], "
}
member_method {
name: "histogram_fixed_width"
@@ -817,10 +801,6 @@
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
- name: "min_max_variable_partitioner"
- argspec: "args=[\'max_partitions\', \'axis\', \'min_slice_size\', \'bytes_per_string_element\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'262144\', \'16\'], "
- }
- member_method {
name: "minimum"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -850,7 +830,7 @@
}
member_method {
name: "norm"
- argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "not_equal"
@@ -902,35 +882,35 @@
}
member_method {
name: "reduce_all"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_any"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_logsumexp"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_max"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_mean"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_min"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_prod"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "reduce_sum"
- argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
}
member_method {
name: "register_tensor_conversion_function"
@@ -941,10 +921,6 @@
argspec: "args=[\'input_shape\', \'block_shape\', \'base_paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
- name: "reset_default_graph"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "reshape"
argspec: "args=[\'tensor\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -970,7 +946,7 @@
}
member_method {
name: "scalar_mul"
- argspec: "args=[\'scalar\', \'x\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'scalar\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "scan"
@@ -1037,6 +1013,10 @@
argspec: "args=[\'input_\', \'begin\', \'size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "sort"
+ argspec: "args=[\'values\', \'axis\', \'direction\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'ASCENDING\', \'None\'], "
+ }
+ member_method {
name: "space_to_batch_nd"
argspec: "args=[\'input\', \'block_shape\', \'paddings\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1045,10 +1025,6 @@
argspec: "args=[\'axis\', \'sp_inputs\', \'expand_nonconcat_dim\', \'concat_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\'], "
}
member_method {
- name: "sparse_to_dense"
- argspec: "args=[\'sparse_indices\', \'output_shape\', \'sparse_values\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], "
- }
- member_method {
name: "split"
argspec: "args=[\'value\', \'num_or_size_splits\', \'axis\', \'num\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \'split\'], "
}
@@ -1062,7 +1038,7 @@
}
member_method {
name: "squeeze"
- argspec: "args=[\'input\', \'axis\', \'name\', \'squeeze_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ argspec: "args=[\'input\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "stack"
@@ -1110,17 +1086,13 @@
}
member_method {
name: "transpose"
- argspec: "args=[\'a\', \'perm\', \'name\', \'conjugate\'], varargs=None, keywords=None, defaults=[\'None\', \'transpose\', \'False\'], "
+ argspec: "args=[\'a\', \'perm\', \'conjugate\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'transpose\'], "
}
member_method {
name: "truediv"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
- name: "truncated_normal"
- argspec: "args=[\'shape\', \'mean\', \'stddev\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\', \'None\'], "
- }
- member_method {
name: "truncatediv"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -1130,7 +1102,7 @@
}
member_method {
name: "tuple"
- argspec: "args=[\'tensors\', \'name\', \'control_inputs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'tensors\', \'control_inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "unique"
@@ -1149,10 +1121,6 @@
argspec: "args=[\'value\', \'num\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'unstack\'], "
}
member_method {
- name: "variable_axis_size_partitioner"
- argspec: "args=[\'max_shard_bytes\', \'axis\', \'bytes_per_string_element\', \'max_shards\'], varargs=None, keywords=None, defaults=[\'0\', \'16\', \'None\'], "
- }
- member_method {
name: "variable_creator_scope"
argspec: "args=[\'variable_creator\'], varargs=None, keywords=None, defaults=None"
}
@@ -1162,7 +1130,7 @@
}
member_method {
name: "while_loop"
- argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'name\', \'maximum_iterations\', \'return_same_structure\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'None\', \'False\'], "
+ argspec: "args=[\'cond\', \'body\', \'loop_vars\', \'shape_invariants\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'maximum_iterations\', \'return_same_structure\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "zeros"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt
index 2948b73..632c2f8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt
@@ -34,7 +34,7 @@
}
member_method {
name: "quantize_and_dequantize"
- argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'None\'], "
+ argspec: "args=[\'input\', \'input_min\', \'input_max\', \'signed_input\', \'num_bits\', \'range_given\', \'round_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'8\', \'False\', \'HALF_TO_EVEN\', \'None\'], "
}
member_method {
name: "quantized_concat"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt
index ce8d277..de5cb6b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt
@@ -29,7 +29,7 @@
argspec: "args=[\'shape\', \'lam\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\', \'None\'], "
}
member_method {
- name: "set_random_seed"
+ name: "set_seed"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=None"
}
member_method {
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index 16b7f14..2ccb3a4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -14,7 +14,7 @@
}
member_method {
name: "reduce_join"
- argspec: "args=[\'inputs\', \'axis\', \'keep_dims\', \'separator\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'\', \'None\', \'None\'], "
+ argspec: "args=[\'inputs\', \'axis\', \'keepdims\', \'separator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'\', \'None\'], "
}
member_method {
name: "regex_full_match"
@@ -34,7 +34,7 @@
}
member_method {
name: "substr"
- argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
+ argspec: "args=[\'input\', \'pos\', \'len\', \'unit\', \'name\'], varargs=None, keywords=None, defaults=[\'BYTE\', \'None\'], "
}
member_method {
name: "to_hash_bucket"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt
index 26c979c..42a74a6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.summary.pbtxt
@@ -44,12 +44,4 @@
name: "import_event"
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
- member_method {
- name: "record_summaries"
- argspec: "args=[\'boolean\'], varargs=None, keywords=None, defaults=[\'True\'], "
- }
- member_method {
- name: "should_record_summaries"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.test.-benchmark.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.test.-benchmark.pbtxt
index df528e2..6fc489c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.test.-benchmark.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.test.-benchmark.pbtxt
@@ -7,6 +7,10 @@
name: "__init__"
}
member_method {
+ name: "evaluate"
+ argspec: "args=[\'self\', \'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "is_abstract"
argspec: "args=[\'cls\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-adadelta-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-adadelta-optimizer.pbtxt
deleted file mode 100644
index 1f1d8b6..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-adadelta-optimizer.pbtxt
+++ /dev/null
@@ -1,51 +0,0 @@
-path: "tensorflow.train.AdadeltaOptimizer"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.adadelta.AdadeltaOptimizer\'>"
- is_instance: "<class \'tensorflow.python.training.optimizer.Optimizer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "GATE_GRAPH"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_NONE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_OP"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'learning_rate\', \'rho\', \'epsilon\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.001\', \'0.95\', \'1e-08\', \'False\', \'Adadelta\'], "
- }
- member_method {
- name: "apply_gradients"
- argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compute_gradients"
- argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], "
- }
- member_method {
- name: "get_name"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot"
- argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "minimize"
- argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-adagrad-d-a-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-adagrad-d-a-optimizer.pbtxt
deleted file mode 100644
index a7c05d4..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-adagrad-d-a-optimizer.pbtxt
+++ /dev/null
@@ -1,51 +0,0 @@
-path: "tensorflow.train.AdagradDAOptimizer"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.adagrad_da.AdagradDAOptimizer\'>"
- is_instance: "<class \'tensorflow.python.training.optimizer.Optimizer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "GATE_GRAPH"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_NONE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_OP"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'learning_rate\', \'global_step\', \'initial_gradient_squared_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.1\', \'0.0\', \'0.0\', \'False\', \'AdagradDA\'], "
- }
- member_method {
- name: "apply_gradients"
- argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compute_gradients"
- argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], "
- }
- member_method {
- name: "get_name"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot"
- argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "minimize"
- argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-adagrad-optimizer.pbtxt
deleted file mode 100644
index bc8b923..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-adagrad-optimizer.pbtxt
+++ /dev/null
@@ -1,51 +0,0 @@
-path: "tensorflow.train.AdagradOptimizer"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.adagrad.AdagradOptimizer\'>"
- is_instance: "<class \'tensorflow.python.training.optimizer.Optimizer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "GATE_GRAPH"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_NONE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_OP"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator_value\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.1\', \'False\', \'Adagrad\'], "
- }
- member_method {
- name: "apply_gradients"
- argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compute_gradients"
- argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], "
- }
- member_method {
- name: "get_name"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot"
- argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "minimize"
- argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-adam-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-adam-optimizer.pbtxt
deleted file mode 100644
index 5d17be9..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-adam-optimizer.pbtxt
+++ /dev/null
@@ -1,51 +0,0 @@
-path: "tensorflow.train.AdamOptimizer"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.adam.AdamOptimizer\'>"
- is_instance: "<class \'tensorflow.python.training.optimizer.Optimizer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "GATE_GRAPH"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_NONE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_OP"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'learning_rate\', \'beta1\', \'beta2\', \'epsilon\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.001\', \'0.9\', \'0.999\', \'1e-08\', \'False\', \'Adam\'], "
- }
- member_method {
- name: "apply_gradients"
- argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compute_gradients"
- argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], "
- }
- member_method {
- name: "get_name"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot"
- argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "minimize"
- argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-chief-session-creator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-chief-session-creator.pbtxt
deleted file mode 100644
index abbe273..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-chief-session-creator.pbtxt
+++ /dev/null
@@ -1,14 +0,0 @@
-path: "tensorflow.train.ChiefSessionCreator"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.monitored_session.ChiefSessionCreator\'>"
- is_instance: "<class \'tensorflow.python.training.monitored_session.SessionCreator\'>"
- is_instance: "<type \'object\'>"
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'scaffold\', \'master\', \'config\', \'checkpoint_dir\', \'checkpoint_filename_with_path\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "create_session"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-ftrl-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-ftrl-optimizer.pbtxt
deleted file mode 100644
index d265fde..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-ftrl-optimizer.pbtxt
+++ /dev/null
@@ -1,51 +0,0 @@
-path: "tensorflow.train.FtrlOptimizer"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.ftrl.FtrlOptimizer\'>"
- is_instance: "<class \'tensorflow.python.training.optimizer.Optimizer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "GATE_GRAPH"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_NONE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_OP"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'learning_rate\', \'learning_rate_power\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\', \'accum_name\', \'linear_name\', \'l2_shrinkage_regularization_strength\'], varargs=None, keywords=None, defaults=[\'-0.5\', \'0.1\', \'0.0\', \'0.0\', \'False\', \'Ftrl\', \'None\', \'None\', \'0.0\'], "
- }
- member_method {
- name: "apply_gradients"
- argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compute_gradients"
- argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], "
- }
- member_method {
- name: "get_name"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot"
- argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "minimize"
- argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-gradient-descent-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-gradient-descent-optimizer.pbtxt
deleted file mode 100644
index c673e29..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-gradient-descent-optimizer.pbtxt
+++ /dev/null
@@ -1,51 +0,0 @@
-path: "tensorflow.train.GradientDescentOptimizer"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.gradient_descent.GradientDescentOptimizer\'>"
- is_instance: "<class \'tensorflow.python.training.optimizer.Optimizer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "GATE_GRAPH"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_NONE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_OP"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'learning_rate\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'GradientDescent\'], "
- }
- member_method {
- name: "apply_gradients"
- argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compute_gradients"
- argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], "
- }
- member_method {
- name: "get_name"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot"
- argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "minimize"
- argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-looper-thread.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-looper-thread.pbtxt
deleted file mode 100644
index c618590..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-looper-thread.pbtxt
+++ /dev/null
@@ -1,73 +0,0 @@
-path: "tensorflow.train.LooperThread"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.coordinator.LooperThread\'>"
- is_instance: "<class \'threading.Thread\'>"
- member {
- name: "daemon"
- mtype: "<type \'property\'>"
- }
- member {
- name: "ident"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'coord\', \'timer_interval_secs\', \'target\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "getName"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "isAlive"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "isDaemon"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "is_alive"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "join"
- argspec: "args=[\'self\', \'timeout\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "loop"
- argspec: "args=[\'coord\', \'timer_interval_secs\', \'target\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "run"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "run_loop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "setDaemon"
- argspec: "args=[\'self\', \'daemonic\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "setName"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "start"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "start_loop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "stop_loop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-momentum-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-momentum-optimizer.pbtxt
deleted file mode 100644
index 8199f63..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-momentum-optimizer.pbtxt
+++ /dev/null
@@ -1,51 +0,0 @@
-path: "tensorflow.train.MomentumOptimizer"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.momentum.MomentumOptimizer\'>"
- is_instance: "<class \'tensorflow.python.training.optimizer.Optimizer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "GATE_GRAPH"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_NONE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_OP"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'learning_rate\', \'momentum\', \'use_locking\', \'name\', \'use_nesterov\'], varargs=None, keywords=None, defaults=[\'False\', \'Momentum\', \'False\'], "
- }
- member_method {
- name: "apply_gradients"
- argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compute_gradients"
- argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], "
- }
- member_method {
- name: "get_name"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot"
- argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "minimize"
- argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-monitored-session.-step-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-monitored-session.-step-context.pbtxt
deleted file mode 100644
index 03efe66..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-monitored-session.-step-context.pbtxt
+++ /dev/null
@@ -1,21 +0,0 @@
-path: "tensorflow.train.MonitoredSession.StepContext"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.monitored_session.StepContext\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "session"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'session\', \'run_with_hooks_fn\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "request_stop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "run_with_hooks"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-monitored-session.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-monitored-session.pbtxt
deleted file mode 100644
index 09b7b3f..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-monitored-session.pbtxt
+++ /dev/null
@@ -1,34 +0,0 @@
-path: "tensorflow.train.MonitoredSession"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.monitored_session.MonitoredSession\'>"
- is_instance: "<class \'tensorflow.python.training.monitored_session._MonitoredSession\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "StepContext"
- mtype: "<type \'type\'>"
- }
- member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'session_creator\', \'hooks\', \'stop_grace_period_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'120\'], "
- }
- member_method {
- name: "close"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "run"
- argspec: "args=[\'self\', \'fetches\', \'feed_dict\', \'options\', \'run_metadata\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "run_step_fn"
- argspec: "args=[\'self\', \'step_fn\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "should_stop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-optimizer.pbtxt
deleted file mode 100644
index 876bb35..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-optimizer.pbtxt
+++ /dev/null
@@ -1,50 +0,0 @@
-path: "tensorflow.train.Optimizer"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.optimizer.Optimizer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "GATE_GRAPH"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_NONE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_OP"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "apply_gradients"
- argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compute_gradients"
- argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], "
- }
- member_method {
- name: "get_name"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot"
- argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "minimize"
- argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-proximal-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-proximal-adagrad-optimizer.pbtxt
deleted file mode 100644
index 14349a7..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-proximal-adagrad-optimizer.pbtxt
+++ /dev/null
@@ -1,51 +0,0 @@
-path: "tensorflow.train.ProximalAdagradOptimizer"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.proximal_adagrad.ProximalAdagradOptimizer\'>"
- is_instance: "<class \'tensorflow.python.training.optimizer.Optimizer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "GATE_GRAPH"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_NONE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_OP"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'learning_rate\', \'initial_accumulator_value\', \'l1_regularization_strength\', \'l2_regularization_strength\', \'use_locking\', \'name\'], varargs=None, keywords=None, defaults=[\'0.1\', \'0.0\', \'0.0\', \'False\', \'ProximalAdagrad\'], "
- }
- member_method {
- name: "apply_gradients"
- argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compute_gradients"
- argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], "
- }
- member_method {
- name: "get_name"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot"
- argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "minimize"
- argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-r-m-s-prop-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-r-m-s-prop-optimizer.pbtxt
deleted file mode 100644
index 906384a..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-r-m-s-prop-optimizer.pbtxt
+++ /dev/null
@@ -1,51 +0,0 @@
-path: "tensorflow.train.RMSPropOptimizer"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.rmsprop.RMSPropOptimizer\'>"
- is_instance: "<class \'tensorflow.python.training.optimizer.Optimizer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "GATE_GRAPH"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_NONE"
- mtype: "<type \'int\'>"
- }
- member {
- name: "GATE_OP"
- mtype: "<type \'int\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'learning_rate\', \'decay\', \'momentum\', \'epsilon\', \'use_locking\', \'centered\', \'name\'], varargs=None, keywords=None, defaults=[\'0.9\', \'0.0\', \'1e-10\', \'False\', \'False\', \'RMSProp\'], "
- }
- member_method {
- name: "apply_gradients"
- argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compute_gradients"
- argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], "
- }
- member_method {
- name: "get_name"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot"
- argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_slot_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "minimize"
- argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], "
- }
- member_method {
- name: "variables"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-session-creator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-creator.pbtxt
deleted file mode 100644
index beb2327..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-session-creator.pbtxt
+++ /dev/null
@@ -1,12 +0,0 @@
-path: "tensorflow.train.SessionCreator"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.monitored_session.SessionCreator\'>"
- is_instance: "<type \'object\'>"
- member_method {
- name: "__init__"
- }
- member_method {
- name: "create_session"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-session-manager.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-session-manager.pbtxt
deleted file mode 100644
index 448764f..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-session-manager.pbtxt
+++ /dev/null
@@ -1,21 +0,0 @@
-path: "tensorflow.train.SessionManager"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.session_manager.SessionManager\'>"
- is_instance: "<type \'object\'>"
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'local_init_op\', \'ready_op\', \'ready_for_local_init_op\', \'graph\', \'recovery_wait_secs\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'30\', \'None\'], "
- }
- member_method {
- name: "prepare_session"
- argspec: "args=[\'self\', \'master\', \'init_op\', \'saver\', \'checkpoint_dir\', \'checkpoint_filename_with_path\', \'wait_for_checkpoint\', \'max_wait_secs\', \'config\', \'init_feed_dict\', \'init_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'False\', \'7200\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "recover_session"
- argspec: "args=[\'self\', \'master\', \'saver\', \'checkpoint_dir\', \'checkpoint_filename_with_path\', \'wait_for_checkpoint\', \'max_wait_secs\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'7200\', \'None\'], "
- }
- member_method {
- name: "wait_for_session"
- argspec: "args=[\'self\', \'master\', \'config\', \'max_wait_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'inf\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-singular-monitored-session.-step-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-singular-monitored-session.-step-context.pbtxt
deleted file mode 100644
index 36d8ce7..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-singular-monitored-session.-step-context.pbtxt
+++ /dev/null
@@ -1,21 +0,0 @@
-path: "tensorflow.train.SingularMonitoredSession.StepContext"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.monitored_session.StepContext\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "session"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'session\', \'run_with_hooks_fn\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "request_stop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "run_with_hooks"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-singular-monitored-session.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-singular-monitored-session.pbtxt
deleted file mode 100644
index de0f2c1..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-singular-monitored-session.pbtxt
+++ /dev/null
@@ -1,38 +0,0 @@
-path: "tensorflow.train.SingularMonitoredSession"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.monitored_session.SingularMonitoredSession\'>"
- is_instance: "<class \'tensorflow.python.training.monitored_session._MonitoredSession\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "StepContext"
- mtype: "<type \'type\'>"
- }
- member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'hooks\', \'scaffold\', \'master\', \'config\', \'checkpoint_dir\', \'stop_grace_period_secs\', \'checkpoint_filename_with_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'\', \'None\', \'None\', \'120\', \'None\'], "
- }
- member_method {
- name: "close"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "raw_session"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "run"
- argspec: "args=[\'self\', \'fetches\', \'feed_dict\', \'options\', \'run_metadata\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "run_step_fn"
- argspec: "args=[\'self\', \'step_fn\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "should_stop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-supervisor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-supervisor.pbtxt
deleted file mode 100644
index 9677e5a..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-supervisor.pbtxt
+++ /dev/null
@@ -1,153 +0,0 @@
-path: "tensorflow.train.Supervisor"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.supervisor.Supervisor\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "USE_DEFAULT"
- mtype: "<type \'int\'>"
- }
- member {
- name: "coord"
- mtype: "<type \'property\'>"
- }
- member {
- name: "global_step"
- mtype: "<type \'property\'>"
- }
- member {
- name: "init_feed_dict"
- mtype: "<type \'property\'>"
- }
- member {
- name: "init_op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "is_chief"
- mtype: "<type \'property\'>"
- }
- member {
- name: "ready_for_local_init_op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "ready_op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "save_model_secs"
- mtype: "<type \'property\'>"
- }
- member {
- name: "save_path"
- mtype: "<type \'property\'>"
- }
- member {
- name: "save_summaries_secs"
- mtype: "<type \'property\'>"
- }
- member {
- name: "saver"
- mtype: "<type \'property\'>"
- }
- member {
- name: "session_manager"
- mtype: "<type \'property\'>"
- }
- member {
- name: "summary_op"
- mtype: "<type \'property\'>"
- }
- member {
- name: "summary_writer"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "Loop"
- argspec: "args=[\'self\', \'timer_interval_secs\', \'target\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "PrepareSession"
- argspec: "args=[\'self\', \'master\', \'config\', \'wait_for_checkpoint\', \'max_wait_secs\', \'start_standard_services\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'False\', \'7200\', \'True\'], "
- }
- member_method {
- name: "RequestStop"
- argspec: "args=[\'self\', \'ex\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "ShouldStop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "StartQueueRunners"
- argspec: "args=[\'self\', \'sess\', \'queue_runners\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "StartStandardServices"
- argspec: "args=[\'self\', \'sess\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "Stop"
- argspec: "args=[\'self\', \'threads\', \'close_summary_writer\', \'ignore_live_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'False\'], "
- }
- member_method {
- name: "StopOnException"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "SummaryComputed"
- argspec: "args=[\'self\', \'sess\', \'summary\', \'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "WaitForStop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'graph\', \'ready_op\', \'ready_for_local_init_op\', \'is_chief\', \'init_op\', \'init_feed_dict\', \'local_init_op\', \'logdir\', \'summary_op\', \'saver\', \'global_step\', \'save_summaries_secs\', \'save_model_secs\', \'recovery_wait_secs\', \'stop_grace_secs\', \'checkpoint_basename\', \'session_manager\', \'summary_writer\', \'init_fn\', \'local_init_run_options\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'0\', \'True\', \'0\', \'None\', \'0\', \'None\', \'0\', \'0\', \'0\', \'120\', \'600\', \'30\', \'120\', \'model.ckpt\', \'None\', \'0\', \'None\', \'None\'], "
- }
- member_method {
- name: "loop"
- argspec: "args=[\'self\', \'timer_interval_secs\', \'target\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "managed_session"
- argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
- }
- member_method {
- name: "prepare_or_wait_for_session"
- argspec: "args=[\'self\', \'master\', \'config\', \'wait_for_checkpoint\', \'max_wait_secs\', \'start_standard_services\'], varargs=None, keywords=None, defaults=[\'\', \'None\', \'False\', \'7200\', \'True\'], "
- }
- member_method {
- name: "request_stop"
- argspec: "args=[\'self\', \'ex\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "should_stop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "start_queue_runners"
- argspec: "args=[\'self\', \'sess\', \'queue_runners\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "start_standard_services"
- argspec: "args=[\'self\', \'sess\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "stop"
- argspec: "args=[\'self\', \'threads\', \'close_summary_writer\', \'ignore_live_threads\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'False\'], "
- }
- member_method {
- name: "stop_on_exception"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "summary_computed"
- argspec: "args=[\'self\', \'sess\', \'summary\', \'global_step\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "wait_for_stop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
deleted file mode 100644
index 39b946b..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt
+++ /dev/null
@@ -1,43 +0,0 @@
-path: "tensorflow.train.VocabInfo"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
- is_instance: "<class \'tensorflow.python.training.warm_starting_util.VocabInfo\'>"
- is_instance: "<type \'tuple\'>"
- member {
- name: "axis"
- mtype: "<type \'property\'>"
- }
- member {
- name: "backup_initializer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "new_vocab"
- mtype: "<type \'property\'>"
- }
- member {
- name: "new_vocab_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "num_oov_buckets"
- mtype: "<type \'property\'>"
- }
- member {
- name: "old_vocab"
- mtype: "<type \'property\'>"
- }
- member {
- name: "old_vocab_size"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- }
- member_method {
- name: "count"
- }
- member_method {
- name: "index"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-worker-session-creator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-worker-session-creator.pbtxt
deleted file mode 100644
index ac26358..0000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.-worker-session-creator.pbtxt
+++ /dev/null
@@ -1,14 +0,0 @@
-path: "tensorflow.train.WorkerSessionCreator"
-tf_class {
- is_instance: "<class \'tensorflow.python.training.monitored_session.WorkerSessionCreator\'>"
- is_instance: "<class \'tensorflow.python.training.monitored_session.SessionCreator\'>"
- is_instance: "<type \'object\'>"
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'scaffold\', \'master\', \'config\', \'max_wait_secs\'], varargs=None, keywords=None, defaults=[\'None\', \'\', \'None\', \'1800\'], "
- }
- member_method {
- name: "create_session"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
index 89d9270..a30f673 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
@@ -1,22 +1,6 @@
path: "tensorflow.train"
tf_module {
member {
- name: "AdadeltaOptimizer"
- mtype: "<type \'type\'>"
- }
- member {
- name: "AdagradDAOptimizer"
- mtype: "<type \'type\'>"
- }
- member {
- name: "AdagradOptimizer"
- mtype: "<type \'type\'>"
- }
- member {
- name: "AdamOptimizer"
- mtype: "<type \'type\'>"
- }
- member {
name: "BytesList"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
@@ -37,10 +21,6 @@
mtype: "<type \'type\'>"
}
member {
- name: "ChiefSessionCreator"
- mtype: "<type \'type\'>"
- }
- member {
name: "ClusterDef"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
@@ -89,18 +69,10 @@
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
member {
- name: "FtrlOptimizer"
- mtype: "<type \'type\'>"
- }
- member {
name: "GlobalStepWaiterHook"
mtype: "<type \'type\'>"
}
member {
- name: "GradientDescentOptimizer"
- mtype: "<type \'type\'>"
- }
- member {
name: "Int64List"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
@@ -113,18 +85,6 @@
mtype: "<type \'type\'>"
}
member {
- name: "LooperThread"
- mtype: "<type \'type\'>"
- }
- member {
- name: "MomentumOptimizer"
- mtype: "<type \'type\'>"
- }
- member {
- name: "MonitoredSession"
- mtype: "<type \'type\'>"
- }
- member {
name: "NanLossDuringTrainingError"
mtype: "<type \'type\'>"
}
@@ -133,22 +93,10 @@
mtype: "<type \'type\'>"
}
member {
- name: "Optimizer"
- mtype: "<type \'type\'>"
- }
- member {
- name: "ProximalAdagradOptimizer"
- mtype: "<type \'type\'>"
- }
- member {
name: "ProximalGradientDescentOptimizer"
mtype: "<type \'type\'>"
}
member {
- name: "RMSPropOptimizer"
- mtype: "<type \'type\'>"
- }
- member {
name: "Scaffold"
mtype: "<type \'type\'>"
}
@@ -169,14 +117,6 @@
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
member {
- name: "SessionCreator"
- mtype: "<type \'type\'>"
- }
- member {
- name: "SessionManager"
- mtype: "<type \'type\'>"
- }
- member {
name: "SessionRunArgs"
mtype: "<type \'type\'>"
}
@@ -193,10 +133,6 @@
mtype: "<type \'type\'>"
}
member {
- name: "SingularMonitoredSession"
- mtype: "<type \'type\'>"
- }
- member {
name: "StepCounterHook"
mtype: "<type \'type\'>"
}
@@ -208,18 +144,6 @@
name: "SummarySaverHook"
mtype: "<type \'type\'>"
}
- member {
- name: "Supervisor"
- mtype: "<type \'type\'>"
- }
- member {
- name: "VocabInfo"
- mtype: "<type \'type\'>"
- }
- member {
- name: "WorkerSessionCreator"
- mtype: "<type \'type\'>"
- }
member_method {
name: "cosine_decay"
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], "
@@ -269,7 +193,7 @@
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'initial_variance\', \'variance_decay\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'1.0\', \'0.55\', \'0.5\', \'0.0\', \'0.001\', \'None\'], "
}
member_method {
- name: "piecewise_constant"
+ name: "piecewise_constant_decay"
argspec: "args=[\'x\', \'boundaries\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index b0f3742..cba6246 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -126,9 +126,9 @@
filtered_file_list = []
filtered_package_prefixes = ['tensorflow.%s.' % p for p in _NON_CORE_PACKAGES]
for f in golden_file_list:
- if any([
+ if any(
f.rsplit('/')[-1].startswith(pre) for pre in filtered_package_prefixes
- ]):
+ ):
continue
filtered_file_list.append(f)
return filtered_file_list
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.0-cudnn7-ubuntu14.04 b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.0-cudnn7-ubuntu14.04
new file mode 100644
index 0000000..85b9d94
--- /dev/null
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda10.0-cudnn7-ubuntu14.04
@@ -0,0 +1,75 @@
+# To push a new version, run:
+# $ docker build -f Dockerfile.rbe.cuda10.0-cudnn7-ubuntu14.04 \
+# --tag "gcr.io/asci-toolchain/nosla-cuda10.0-cudnn7-ubuntu14.04" .
+# $ docker push gcr.io/asci-toolchain/nosla-cuda10.0-cudnn7-ubuntu14.04
+
+FROM ubuntu:14.04
+LABEL maintainer="Manuel Klimek <klimek@google.com>"
+
+RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates apt-transport-https gnupg-curl && \
+ rm -rf /var/lib/apt/lists/* && \
+ NVIDIA_GPGKEY_SUM=d1be581509378368edeec8c1eb2958702feedf3bc3d17011adbf24efacce4ab5 && \
+ NVIDIA_GPGKEY_FPR=ae09fe4bbd223a84b2ccfce3f60f4b3d7fa2af80 && \
+ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64/7fa2af80.pub && \
+ apt-key adv --export --no-emit-version -a $NVIDIA_GPGKEY_FPR | tail -n +2 > cudasign.pub && \
+ echo "$NVIDIA_GPGKEY_SUM cudasign.pub" | sha256sum -c --strict - && rm cudasign.pub && \
+ echo "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1604/x86_64 /" > /etc/apt/sources.list.d/cuda.list && \
+ echo "deb https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list
+
+ENV CUDA_VERSION 10.0.130
+ENV CUDA_PKG_VERSION 10-0=$CUDA_VERSION-1
+ENV CUDNN_VERSION 7.3.1.20
+ENV NCCL_VERSION 2.3.5
+ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
+ENV NVIDIA_REQUIRE_CUDA "cuda>=10.0,driver>=410"
+ENV NVIDIA_VISIBLE_DEVICES all
+ENV PATH /usr/local/cuda/bin:${PATH}
+
+# TODO(b/110903506): /usr/loca/cuda/lib64/stubs should not be needed in
+# LD_LIBRARY_PATH. The stubs/libcuda.so is not meant to used at runtime. The
+# correct way to pass the path to bfd-ld is to pass
+# -Wl,-rpath-link=/usr/local/cuda/lib64/stubs to all binaries transitively
+# depending on libcuda. Optimally, builds targeting cuda would do that
+# internally.
+ENV LIBRARY_PATH /usr/local/cuda/lib64/stubs
+
+LABEL com.nvidia.cudnn.version="${CUDNN_VERSION}"
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ cuda-command-line-tools-$CUDA_PKG_VERSION \
+ cuda-compat-10-0=410.48-1 \
+ cuda-cudart-$CUDA_PKG_VERSION \
+ cuda-libraries-$CUDA_PKG_VERSION \
+ cuda-libraries-dev-$CUDA_PKG_VERSION \
+ cuda-minimal-build-$CUDA_PKG_VERSION \
+ cuda-nvml-dev-$CUDA_PKG_VERSION \
+ cuda-nvtx-$CUDA_PKG_VERSION \
+ libcudnn7=$CUDNN_VERSION-1+cuda10.0 \
+ libcudnn7=$CUDNN_VERSION-1+cuda10.0 \
+ libcudnn7-dev=$CUDNN_VERSION-1+cuda10.0 \
+ libnccl2=$NCCL_VERSION-2+cuda10.0 \
+ libnccl-dev=$NCCL_VERSION-2+cuda10.0 && \
+ ln -s cuda-10.0 /usr/local/cuda && \
+ apt-mark hold libcudnn7 && \
+ apt-mark hold libnccl2 && \
+ rm -rf /var/lib/apt/lists/*
+
+# TODO(b/110903506): Provide a link to the SONAME of libcuda.so.
+# https://github.com/NVIDIA/nvidia-docker/issues/775
+RUN ln -s libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1
+
+# TODO(klimek): Once the TODO in tensorflow's configure.py to correctly find
+# libnccl is resolved, delete this block.
+RUN ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so \
+ && ln -s /usr/lib/x86_64-linux-gnu/libnccl.so /usr/lib/libnccl.so.2
+
+# Copy and run the install scripts.
+COPY install/*.sh /install/
+ARG DEBIAN_FRONTEND=noninteractive
+RUN /install/install_bootstrap_deb_packages.sh
+RUN add-apt-repository -y ppa:openjdk-r/ppa && \
+ add-apt-repository -y ppa:george-edison55/cmake-3.x
+RUN /install/install_deb_packages.sh
+RUN /install/install_pip_packages.sh
+RUN /install/install_golang.sh
+
diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD
index f46e36b..a6574da 100644
--- a/tensorflow/tools/compatibility/BUILD
+++ b/tensorflow/tools/compatibility/BUILD
@@ -125,6 +125,16 @@
)
py_test(
+ name = "test_file_v1_10",
+ size = "small",
+ srcs = ["testdata/test_file_v1_10.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
name = "test_file_v2_0",
size = "small",
srcs = ["test_file_v2_0.py"],
diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py
index 56c67b8..90bfab3 100644
--- a/tensorflow/tools/compatibility/ast_edits.py
+++ b/tensorflow/tools/compatibility/ast_edits.py
@@ -40,6 +40,10 @@
* `function_reorders`: maps functions whose argument order has changed to the
list of arguments in the new order
* `function_handle`: maps function names to custom handlers for the function
+ * `function_warnings`: maps full names of functions to warnings that will be
+ printed out if the function is used. (e.g. tf.nn.convolution())
+ * `unrestricted_function_warnings`: maps names of functions to warnings that
+ will be printed out when the function is used (e.g. foo.convolution()).
For an example, see `TFAPIChangeSpec`.
"""
@@ -195,6 +199,29 @@
except KeyError:
pass
+ def _print_warning_for_function_unrestricted(self, node):
+ """Print a warning when specific functions are called.
+
+ The function _print_warning_for_function matches the full name of the called
+ function, e.g., tf.foo.bar(). This function matches the function name that
+ is called, as long as the function is an attribute. For example,
+ `tf.foo.bar()` and `foo.bar()` are matched, but not `bar()`.
+
+ Args:
+ node: ast.Call object
+ """
+ function_warnings = getattr(
+ self._api_change_spec, "unrestricted_function_warnings", {})
+ if isinstance(node.func, ast.Attribute):
+ function_name = node.func.attr
+ try:
+ warning_message = function_warnings[function_name]
+ self._file_edit.add(warning_message,
+ node.lineno, node.col_offset, "", "",
+ error="%s requires manual check." % function_name)
+ except KeyError:
+ pass
+
def _get_attribute_full_path(self, node):
"""Traverse an attribute to generate a full name e.g. tf.foo.bar.
@@ -209,11 +236,11 @@
items = []
while not isinstance(curr, ast.Name):
if not isinstance(curr, ast.Attribute):
- return None
+ return None, None
items.append(curr.attr)
curr = curr.value
items.append(curr.id)
- return ".".join(reversed(items))
+ return ".".join(reversed(items)), items[0]
def _find_true_position(self, node):
"""Return correct line number and column offset for a given node.
@@ -276,9 +303,10 @@
Args:
node: Current Node
"""
+ self._print_warning_for_function_unrestricted(node)
# Find a simple attribute name path e.g. "tf.foo.bar"
- full_name = self._get_attribute_full_path(node.func)
+ full_name, name = self._get_attribute_full_path(node.func)
# Make sure the func is marked as being part of a call
node.func.is_function_for_call = True
@@ -286,6 +314,9 @@
if full_name:
# Call special handlers
function_handles = self._api_change_spec.function_handle
+ glob_name = "*.{}".format(name)
+ if glob_name in function_handles:
+ function_handles[glob_name](self._file_edit, node)
if full_name in function_handles:
function_handles[full_name](self._file_edit, node)
@@ -358,7 +389,7 @@
Args:
node: Node that is of type ast.Attribute
"""
- full_name = self._get_attribute_full_path(node)
+ full_name, _ = self._get_attribute_full_path(node)
if full_name:
self._rename_functions(node, full_name)
self._print_warning_for_function(node, full_name)
diff --git a/tensorflow/tools/compatibility/ast_edits_test.py b/tensorflow/tools/compatibility/ast_edits_test.py
index 08f4ae3..99f20a0 100644
--- a/tensorflow/tools/compatibility/ast_edits_test.py
+++ b/tensorflow/tools/compatibility/ast_edits_test.py
@@ -52,6 +52,10 @@
self.function_handle = {}
self.function_reorders = {}
self.function_keyword_renames = {}
+ self.symbol_renames = {}
+ self.function_warnings = {}
+ self.unrestricted_function_warnings = {}
+ self.change_to_function = {}
class RenameKeywordSpec(NoUpdateSpec):
@@ -391,6 +395,26 @@
_, new_text = self._upgrade(RemoveMultipleKeywordArguments(), text)
self.assertIn(new_text, acceptable_outputs)
+ def testUnrestrictedFunctionWarnings(self):
+ class FooWarningSpec(NoUpdateSpec):
+ """Usages of function attribute foo() prints out a warning."""
+
+ def __init__(self):
+ NoUpdateSpec.__init__(self)
+ self.unrestricted_function_warnings = {"foo": "not good"}
+ texts = ["object.foo()", "get_object().foo()",
+ "get_object().foo()", "object.foo().bar()"]
+ for text in texts:
+ (_, report, _), _ = self._upgrade(FooWarningSpec(), text)
+ self.assertIn("not good", report)
+
+ # Note that foo() won't result in a warning, because in this case foo is
+ # not an attribute, but a name.
+ false_alarms = ["foo", "foo()", "foo.bar()", "obj.run_foo()", "obj.foo"]
+ for text in false_alarms:
+ (_, report, _), _ = self._upgrade(FooWarningSpec(), text)
+ self.assertNotIn("not good", report)
+
if __name__ == "__main__":
test_lib.main()
diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py
index 55a2804..042ca8a 100644
--- a/tensorflow/tools/compatibility/renames_v2.py
+++ b/tensorflow/tools/compatibility/renames_v2.py
@@ -39,6 +39,7 @@
'tf.GRAPH_DEF_VERSION': 'tf.version.GRAPH_DEF_VERSION',
'tf.GRAPH_DEF_VERSION_MIN_CONSUMER': 'tf.version.GRAPH_DEF_VERSION_MIN_CONSUMER',
'tf.GRAPH_DEF_VERSION_MIN_PRODUCER': 'tf.version.GRAPH_DEF_VERSION_MIN_PRODUCER',
+ 'tf.GraphKeys': 'tf.compat.v1.GraphKeys',
'tf.IdentityReader': 'tf.compat.v1.IdentityReader',
'tf.InteractiveSession': 'tf.compat.v1.InteractiveSession',
'tf.LMDBReader': 'tf.compat.v1.LMDBReader',
@@ -91,6 +92,7 @@
'tf.assign': 'tf.compat.v1.assign',
'tf.assign_add': 'tf.compat.v1.assign_add',
'tf.assign_sub': 'tf.compat.v1.assign_sub',
+ 'tf.batch_to_space_nd': 'tf.compat.v1.batch_to_space_nd',
'tf.betainc': 'tf.math.betainc',
'tf.bincount': 'tf.math.bincount',
'tf.ceil': 'tf.math.ceil',
@@ -105,6 +107,7 @@
'tf.convert_to_tensor_or_sparse_tensor': 'tf.compat.v1.convert_to_tensor_or_sparse_tensor',
'tf.count_nonzero': 'tf.compat.v1.count_nonzero',
'tf.count_up_to': 'tf.compat.v1.count_up_to',
+ 'tf.create_partitioned_variables': 'tf.compat.v1.create_partitioned_variables',
'tf.cross': 'tf.linalg.cross',
'tf.cumprod': 'tf.math.cumprod',
'tf.debugging.is_finite': 'tf.math.is_finite',
@@ -114,7 +117,7 @@
'tf.debugging.is_strictly_increasing': 'tf.math.is_strictly_increasing',
'tf.decode_base64': 'tf.io.decode_base64',
'tf.decode_compressed': 'tf.io.decode_compressed',
- 'tf.decode_csv': 'tf.io.decode_csv',
+ 'tf.decode_csv': 'tf.compat.v1.decode_csv',
'tf.decode_json_example': 'tf.io.decode_json_example',
'tf.decode_raw': 'tf.io.decode_raw',
'tf.delete_session_tensor': 'tf.compat.v1.delete_session_tensor',
@@ -153,7 +156,7 @@
'tf.erf': 'tf.math.erf',
'tf.erfc': 'tf.math.erfc',
'tf.expm1': 'tf.math.expm1',
- 'tf.extract_image_patches': 'tf.image.extract_image_patches',
+ 'tf.extract_image_patches': 'tf.compat.v1.extract_image_patches',
'tf.fake_quant_with_min_max_args': 'tf.quantization.fake_quant_with_min_max_args',
'tf.fake_quant_with_min_max_args_gradient': 'tf.quantization.fake_quant_with_min_max_args_gradient',
'tf.fake_quant_with_min_max_vars': 'tf.quantization.fake_quant_with_min_max_vars',
@@ -165,6 +168,7 @@
'tf.fft': 'tf.signal.fft',
'tf.fft2d': 'tf.signal.fft2d',
'tf.fft3d': 'tf.signal.fft3d',
+ 'tf.fixed_size_partitioner': 'tf.compat.v1.fixed_size_partitioner',
'tf.floordiv': 'tf.math.floordiv',
'tf.get_collection': 'tf.compat.v1.get_collection',
'tf.get_collection_ref': 'tf.compat.v1.get_collection_ref',
@@ -176,10 +180,21 @@
'tf.get_session_tensor': 'tf.compat.v1.get_session_tensor',
'tf.get_variable': 'tf.compat.v1.get_variable',
'tf.get_variable_scope': 'tf.compat.v1.get_variable_scope',
+ 'tf.gfile.Copy': 'tf.compat.v1.gfile.Copy',
+ 'tf.gfile.DeleteRecursively': 'tf.compat.v1.gfile.DeleteRecursively',
'tf.gfile.Exists': 'tf.compat.v1.gfile.Exists',
'tf.gfile.FastGFile': 'tf.compat.v1.gfile.FastGFile',
'tf.gfile.GFile': 'tf.compat.v1.gfile.GFile',
+ 'tf.gfile.Glob': 'tf.compat.v1.gfile.Glob',
+ 'tf.gfile.IsDirectory': 'tf.compat.v1.gfile.IsDirectory',
+ 'tf.gfile.ListDirectory': 'tf.compat.v1.gfile.ListDirectory',
+ 'tf.gfile.MakeDirs': 'tf.compat.v1.gfile.MakeDirs',
+ 'tf.gfile.MkDir': 'tf.compat.v1.gfile.MkDir',
'tf.gfile.Open': 'tf.compat.v1.gfile.Open',
+ 'tf.gfile.Remove': 'tf.compat.v1.gfile.Remove',
+ 'tf.gfile.Rename': 'tf.compat.v1.gfile.Rename',
+ 'tf.gfile.Stat': 'tf.compat.v1.gfile.Stat',
+ 'tf.gfile.Walk': 'tf.compat.v1.gfile.Walk',
'tf.global_norm': 'tf.linalg.global_norm',
'tf.global_variables': 'tf.compat.v1.global_variables',
'tf.global_variables_initializer': 'tf.compat.v1.global_variables_initializer',
@@ -198,7 +213,9 @@
'tf.image.resize_area': 'tf.compat.v1.image.resize_area',
'tf.image.resize_bicubic': 'tf.compat.v1.image.resize_bicubic',
'tf.image.resize_bilinear': 'tf.compat.v1.image.resize_bilinear',
+ 'tf.image.resize_images': 'tf.compat.v1.image.resize_images',
'tf.image.resize_nearest_neighbor': 'tf.compat.v1.image.resize_nearest_neighbor',
+ 'tf.image.transpose_image': 'tf.compat.v1.image.transpose_image',
'tf.initialize_all_tables': 'tf.compat.v1.initialize_all_tables',
'tf.initialize_all_variables': 'tf.compat.v1.initialize_all_variables',
'tf.initialize_local_variables': 'tf.compat.v1.initialize_local_variables',
@@ -208,6 +225,7 @@
'tf.initializers.tables_initializer': 'tf.compat.v1.initializers.tables_initializer',
'tf.initializers.variables': 'tf.compat.v1.initializers.variables',
'tf.invert_permutation': 'tf.math.invert_permutation',
+ 'tf.io.tf_record_iterator': 'tf.compat.v1.io.tf_record_iterator',
'tf.is_finite': 'tf.math.is_finite',
'tf.is_inf': 'tf.math.is_inf',
'tf.is_nan': 'tf.math.is_nan',
@@ -283,7 +301,7 @@
'tf.logical_xor': 'tf.math.logical_xor',
'tf.make_template': 'tf.compat.v1.make_template',
'tf.make_tensor_proto': 'tf.compat.v1.make_tensor_proto',
- 'tf.manip.batch_to_space_nd': 'tf.batch_to_space_nd',
+ 'tf.manip.batch_to_space_nd': 'tf.compat.v1.manip.batch_to_space_nd',
'tf.manip.gather_nd': 'tf.gather_nd',
'tf.manip.reshape': 'tf.reshape',
'tf.manip.reverse': 'tf.reverse',
@@ -302,11 +320,47 @@
'tf.matrix_solve_ls': 'tf.linalg.lstsq',
'tf.matrix_transpose': 'tf.linalg.transpose',
'tf.matrix_triangular_solve': 'tf.linalg.triangular_solve',
+ 'tf.metrics.accuracy': 'tf.compat.v1.metrics.accuracy',
+ 'tf.metrics.auc': 'tf.compat.v1.metrics.auc',
+ 'tf.metrics.average_precision_at_k': 'tf.compat.v1.metrics.average_precision_at_k',
+ 'tf.metrics.false_negatives': 'tf.compat.v1.metrics.false_negatives',
+ 'tf.metrics.false_negatives_at_thresholds': 'tf.compat.v1.metrics.false_negatives_at_thresholds',
+ 'tf.metrics.false_positives': 'tf.compat.v1.metrics.false_positives',
+ 'tf.metrics.false_positives_at_thresholds': 'tf.compat.v1.metrics.false_positives_at_thresholds',
+ 'tf.metrics.mean': 'tf.compat.v1.metrics.mean',
+ 'tf.metrics.mean_absolute_error': 'tf.compat.v1.metrics.mean_absolute_error',
+ 'tf.metrics.mean_cosine_distance': 'tf.compat.v1.metrics.mean_cosine_distance',
+ 'tf.metrics.mean_iou': 'tf.compat.v1.metrics.mean_iou',
+ 'tf.metrics.mean_per_class_accuracy': 'tf.compat.v1.metrics.mean_per_class_accuracy',
+ 'tf.metrics.mean_relative_error': 'tf.compat.v1.metrics.mean_relative_error',
+ 'tf.metrics.mean_squared_error': 'tf.compat.v1.metrics.mean_squared_error',
+ 'tf.metrics.mean_tensor': 'tf.compat.v1.metrics.mean_tensor',
+ 'tf.metrics.percentage_below': 'tf.compat.v1.metrics.percentage_below',
+ 'tf.metrics.precision': 'tf.compat.v1.metrics.precision',
+ 'tf.metrics.precision_at_k': 'tf.compat.v1.metrics.precision_at_k',
+ 'tf.metrics.precision_at_thresholds': 'tf.compat.v1.metrics.precision_at_thresholds',
+ 'tf.metrics.precision_at_top_k': 'tf.compat.v1.metrics.precision_at_top_k',
+ 'tf.metrics.recall': 'tf.compat.v1.metrics.recall',
+ 'tf.metrics.recall_at_k': 'tf.compat.v1.metrics.recall_at_k',
+ 'tf.metrics.recall_at_thresholds': 'tf.compat.v1.metrics.recall_at_thresholds',
+ 'tf.metrics.recall_at_top_k': 'tf.compat.v1.metrics.recall_at_top_k',
+ 'tf.metrics.root_mean_squared_error': 'tf.compat.v1.metrics.root_mean_squared_error',
+ 'tf.metrics.sensitivity_at_specificity': 'tf.compat.v1.metrics.sensitivity_at_specificity',
+ 'tf.metrics.sparse_average_precision_at_k': 'tf.compat.v1.metrics.sparse_average_precision_at_k',
+ 'tf.metrics.sparse_precision_at_k': 'tf.compat.v1.metrics.sparse_precision_at_k',
+ 'tf.metrics.specificity_at_sensitivity': 'tf.compat.v1.metrics.specificity_at_sensitivity',
+ 'tf.metrics.true_negatives': 'tf.compat.v1.metrics.true_negatives',
+ 'tf.metrics.true_negatives_at_thresholds': 'tf.compat.v1.metrics.true_negatives_at_thresholds',
+ 'tf.metrics.true_positives': 'tf.compat.v1.metrics.true_positives',
+ 'tf.metrics.true_positives_at_thresholds': 'tf.compat.v1.metrics.true_positives_at_thresholds',
+ 'tf.min_max_variable_partitioner': 'tf.compat.v1.min_max_variable_partitioner',
'tf.model_variables': 'tf.compat.v1.model_variables',
'tf.moving_average_variables': 'tf.compat.v1.moving_average_variables',
'tf.multinomial': 'tf.compat.v1.multinomial',
+ 'tf.nn.bidirectional_dynamic_rnn': 'tf.compat.v1.nn.bidirectional_dynamic_rnn',
'tf.nn.conv3d_backprop_filter_v2': 'tf.nn.conv3d_backprop_filter',
'tf.nn.ctc_beam_search_decoder_v2': 'tf.nn.ctc_beam_search_decoder',
+ 'tf.nn.ctc_loss_v2': 'tf.nn.ctc_loss',
'tf.nn.depthwise_conv2d_native': 'tf.compat.v1.nn.depthwise_conv2d_native',
'tf.nn.depthwise_conv2d_native_backprop_filter': 'tf.nn.depthwise_conv2d_backprop_filter',
'tf.nn.depthwise_conv2d_native_backprop_input': 'tf.nn.depthwise_conv2d_backprop_input',
@@ -321,14 +375,16 @@
'tf.nn.rnn_cell.BasicRNNCell': 'tf.compat.v1.nn.rnn_cell.BasicRNNCell',
'tf.nn.rnn_cell.GRUCell': 'tf.compat.v1.nn.rnn_cell.GRUCell',
'tf.nn.rnn_cell.LSTMCell': 'tf.compat.v1.nn.rnn_cell.LSTMCell',
+ 'tf.nn.rnn_cell.MultiRNNCell': 'tf.compat.v1.nn.rnn_cell.MultiRNNCell',
'tf.nn.softmax_cross_entropy_with_logits_v2': 'tf.nn.softmax_cross_entropy_with_logits',
+ 'tf.nn.static_bidirectional_rnn': 'tf.compat.v1.nn.static_bidirectional_rnn',
'tf.nn.static_rnn': 'tf.compat.v1.nn.static_rnn',
'tf.nn.uniform_candidate_sampler': 'tf.random.uniform_candidate_sampler',
'tf.nn.xw_plus_b': 'tf.compat.v1.nn.xw_plus_b',
'tf.op_scope': 'tf.compat.v1.op_scope',
'tf.orthogonal_initializer': 'tf.keras.initializers.Orthogonal',
- 'tf.parse_example': 'tf.io.parse_example',
- 'tf.parse_single_example': 'tf.io.parse_single_example',
+ 'tf.parse_example': 'tf.compat.v1.parse_example',
+ 'tf.parse_single_example': 'tf.compat.v1.parse_single_example',
'tf.parse_single_sequence_example': 'tf.io.parse_single_sequence_example',
'tf.parse_tensor': 'tf.io.parse_tensor',
'tf.placeholder': 'tf.compat.v1.placeholder',
@@ -347,13 +403,14 @@
'tf.python_io.TFRecordCompressionType': 'tf.io.TFRecordCompressionType',
'tf.python_io.TFRecordOptions': 'tf.io.TFRecordOptions',
'tf.python_io.TFRecordWriter': 'tf.io.TFRecordWriter',
- 'tf.python_io.tf_record_iterator': 'tf.io.tf_record_iterator',
+ 'tf.python_io.tf_record_iterator': 'tf.compat.v1.python_io.tf_record_iterator',
'tf.qr': 'tf.linalg.qr',
'tf.quantize': 'tf.quantization.quantize',
'tf.quantize_v2': 'tf.compat.v1.quantize_v2',
'tf.quantized_concat': 'tf.quantization.quantized_concat',
'tf.random.get_seed': 'tf.compat.v1.random.get_seed',
'tf.random.multinomial': 'tf.compat.v1.random.multinomial',
+ 'tf.random.set_random_seed': 'tf.compat.v1.random.set_random_seed',
'tf.random.stateless_multinomial': 'tf.compat.v1.random.stateless_multinomial',
'tf.random_crop': 'tf.image.random_crop',
'tf.random_gamma': 'tf.random.gamma',
@@ -364,9 +421,10 @@
'tf.read_file': 'tf.io.read_file',
'tf.real': 'tf.math.real',
'tf.reciprocal': 'tf.math.reciprocal',
- 'tf.reduce_join': 'tf.strings.reduce_join',
+ 'tf.reduce_join': 'tf.compat.v1.reduce_join',
'tf.regex_replace': 'tf.strings.regex_replace',
'tf.report_uninitialized_variables': 'tf.compat.v1.report_uninitialized_variables',
+ 'tf.reset_default_graph': 'tf.compat.v1.reset_default_graph',
'tf.resource_loader.get_data_files_path': 'tf.compat.v1.resource_loader.get_data_files_path',
'tf.resource_loader.get_path_to_datafile': 'tf.compat.v1.resource_loader.get_path_to_datafile',
'tf.resource_loader.get_root_dir_with_all_resources': 'tf.compat.v1.resource_loader.get_root_dir_with_all_resources',
@@ -434,10 +492,10 @@
'tf.segment_sum': 'tf.math.segment_sum',
'tf.self_adjoint_eig': 'tf.linalg.eigh',
'tf.self_adjoint_eigvals': 'tf.linalg.eigvalsh',
- 'tf.serialize_many_sparse': 'tf.io.serialize_many_sparse',
- 'tf.serialize_sparse': 'tf.io.serialize_sparse',
+ 'tf.serialize_many_sparse': 'tf.compat.v1.serialize_many_sparse',
+ 'tf.serialize_sparse': 'tf.compat.v1.serialize_sparse',
'tf.serialize_tensor': 'tf.io.serialize_tensor',
- 'tf.set_random_seed': 'tf.random.set_random_seed',
+ 'tf.set_random_seed': 'tf.compat.v1.set_random_seed',
'tf.setdiff1d': 'tf.compat.v1.setdiff1d',
'tf.sets.set_difference': 'tf.sets.difference',
'tf.sets.set_intersection': 'tf.sets.intersection',
@@ -449,6 +507,7 @@
'tf.sparse.merge': 'tf.compat.v1.sparse.merge',
'tf.sparse.placeholder': 'tf.compat.v1.sparse.placeholder',
'tf.sparse.reduce_max_sparse': 'tf.compat.v1.sparse.reduce_max_sparse',
+ 'tf.sparse.reduce_sum_sparse': 'tf.compat.v1.sparse.reduce_sum_sparse',
'tf.sparse_add': 'tf.compat.v1.sparse_add',
'tf.sparse_fill_empty_rows': 'tf.sparse.fill_empty_rows',
'tf.sparse_mask': 'tf.sparse.mask',
@@ -459,8 +518,8 @@
'tf.sparse_placeholder': 'tf.compat.v1.sparse_placeholder',
'tf.sparse_reduce_max': 'tf.compat.v1.sparse_reduce_max',
'tf.sparse_reduce_max_sparse': 'tf.compat.v1.sparse_reduce_max_sparse',
- 'tf.sparse_reduce_sum': 'tf.sparse.reduce_sum',
- 'tf.sparse_reduce_sum_sparse': 'tf.sparse.reduce_sum_sparse',
+ 'tf.sparse_reduce_sum': 'tf.compat.v1.sparse_reduce_sum',
+ 'tf.sparse_reduce_sum_sparse': 'tf.compat.v1.sparse_reduce_sum_sparse',
'tf.sparse_reorder': 'tf.sparse.reorder',
'tf.sparse_reset_shape': 'tf.sparse.reset_shape',
'tf.sparse_reshape': 'tf.sparse.reshape',
@@ -473,6 +532,7 @@
'tf.sparse_split': 'tf.compat.v1.sparse_split',
'tf.sparse_tensor_dense_matmul': 'tf.sparse.sparse_dense_matmul',
'tf.sparse_tensor_to_dense': 'tf.sparse.to_dense',
+ 'tf.sparse_to_dense': 'tf.compat.v1.sparse_to_dense',
'tf.sparse_to_indicator': 'tf.sparse.to_indicator',
'tf.sparse_transpose': 'tf.sparse.transpose',
'tf.spectral.dct': 'tf.signal.dct',
@@ -496,6 +556,15 @@
'tf.string_to_hash_bucket_fast': 'tf.strings.to_hash_bucket_fast',
'tf.string_to_hash_bucket_strong': 'tf.strings.to_hash_bucket_strong',
'tf.string_to_number': 'tf.strings.to_number',
+ 'tf.summary.audio': 'tf.compat.v1.summary.audio',
+ 'tf.summary.get_summary_description': 'tf.compat.v1.summary.get_summary_description',
+ 'tf.summary.histogram': 'tf.compat.v1.summary.histogram',
+ 'tf.summary.image': 'tf.compat.v1.summary.image',
+ 'tf.summary.merge': 'tf.compat.v1.summary.merge',
+ 'tf.summary.merge_all': 'tf.compat.v1.summary.merge_all',
+ 'tf.summary.scalar': 'tf.compat.v1.summary.scalar',
+ 'tf.summary.tensor_summary': 'tf.compat.v1.summary.tensor_summary',
+ 'tf.summary.text': 'tf.compat.v1.summary.text',
'tf.svd': 'tf.linalg.svd',
'tf.tables_initializer': 'tf.compat.v1.tables_initializer',
'tf.test.compute_gradient': 'tf.compat.v1.test.compute_gradient',
@@ -511,13 +580,32 @@
'tf.to_int32': 'tf.compat.v1.to_int32',
'tf.to_int64': 'tf.compat.v1.to_int64',
'tf.trace': 'tf.linalg.trace',
+ 'tf.train.ChiefSessionCreator': 'tf.compat.v1.train.ChiefSessionCreator',
+ 'tf.train.MonitoredSession': 'tf.compat.v1.train.MonitoredSession',
+ 'tf.train.LooperThread': 'tf.compat.v1.train.LooperThread',
+ 'tf.train.AdadeltaOptimizer': 'tf.compat.v1.train.AdadeltaOptimizer',
+ 'tf.train.AdagradDAOptimizer': 'tf.compat.v1.train.AdagradDAOptimizer',
+ 'tf.train.AdagradOptimizer': 'tf.compat.v1.train.AdagradOptimizer',
+ 'tf.train.AdamOptimizer': 'tf.compat.v1.train.AdamOptimizer',
+ 'tf.train.FtrlOptimizer': 'tf.compat.v1.train.FtrlOptimizer',
+ 'tf.train.GradientDescentOptimizer': 'tf.compat.v1.train.GradientDescentOptimizer',
+ 'tf.train.MomentumOptimizer': 'tf.compat.v1.train.MomentumOptimizer',
'tf.train.MonitoredTrainingSession': 'tf.compat.v1.train.MonitoredTrainingSession',
'tf.train.NewCheckpointReader': 'tf.compat.v1.train.NewCheckpointReader',
+ 'tf.train.Optimizer': 'tf.compat.v1.train.Optimizer',
'tf.train.ProfilerHook': 'tf.compat.v1.train.ProfilerHook',
+ 'tf.train.ProximalAdagradOptimizer': 'tf.compat.v1.train.ProximalAdagradOptimizer',
'tf.train.QueueRunner': 'tf.compat.v1.train.QueueRunner',
+ 'tf.train.RMSPropOptimizer': 'tf.compat.v1.train.RMSPropOptimizer',
'tf.train.Saver': 'tf.compat.v1.train.Saver',
'tf.train.SaverDef': 'tf.compat.v1.train.SaverDef',
+ 'tf.train.SessionCreator': 'tf.compat.v1.train.SessionCreator',
+ 'tf.train.SessionManager': 'tf.compat.v1.train.SessionManager',
+ 'tf.train.SingularMonitoredSession': 'tf.compat.v1.train.SingularMonitoredSession',
+ 'tf.train.Supervisor': 'tf.compat.v1.train.Supervisor',
'tf.train.SyncReplicasOptimizer': 'tf.compat.v1.train.SyncReplicasOptimizer',
+ 'tf.train.WorkerSessionCreator': 'tf.compat.v1.train.WorkerSessionCreator',
+ 'tf.train.VocabInfo': 'tf.compat.v1.train.VocabInfo',
'tf.train.add_queue_runner': 'tf.compat.v1.train.add_queue_runner',
'tf.train.assert_global_step': 'tf.compat.v1.train.assert_global_step',
'tf.train.basic_train_loop': 'tf.compat.v1.train.basic_train_loop',
@@ -541,6 +629,7 @@
'tf.train.maybe_batch_join': 'tf.compat.v1.train.maybe_batch_join',
'tf.train.maybe_shuffle_batch': 'tf.compat.v1.train.maybe_shuffle_batch',
'tf.train.maybe_shuffle_batch_join': 'tf.compat.v1.train.maybe_shuffle_batch_join',
+ 'tf.train.piecewise_constant': 'tf.compat.v1.train.piecewise_constant',
'tf.train.queue_runner.QueueRunner': 'tf.compat.v1.train.queue_runner.QueueRunner',
'tf.train.queue_runner.add_queue_runner': 'tf.compat.v1.train.queue_runner.add_queue_runner',
'tf.train.queue_runner.start_queue_runners': 'tf.compat.v1.train.queue_runner.start_queue_runners',
@@ -555,6 +644,7 @@
'tf.train.update_checkpoint_state': 'tf.compat.v1.train.update_checkpoint_state',
'tf.train.write_graph': 'tf.io.write_graph',
'tf.trainable_variables': 'tf.compat.v1.trainable_variables',
+ 'tf.truncated_normal': 'tf.random.truncated_normal',
'tf.uniform_unit_scaling_initializer': 'tf.initializers.uniform_unit_scaling',
'tf.unsorted_segment_max': 'tf.math.unsorted_segment_max',
'tf.unsorted_segment_mean': 'tf.math.unsorted_segment_mean',
@@ -562,6 +652,7 @@
'tf.unsorted_segment_prod': 'tf.math.unsorted_segment_prod',
'tf.unsorted_segment_sqrt_n': 'tf.math.unsorted_segment_sqrt_n',
'tf.unsorted_segment_sum': 'tf.math.unsorted_segment_sum',
+ 'tf.variable_axis_size_partitioner': 'tf.compat.v1.variable_axis_size_partitioner',
'tf.variable_op_scope': 'tf.compat.v1.variable_op_scope',
'tf.variable_scope': 'tf.compat.v1.variable_scope',
'tf.variables_initializer': 'tf.compat.v1.variables_initializer',
diff --git a/tensorflow/tools/compatibility/testdata/test_file_v1_10.py b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py
index e5ca8d3..fd68878 100644
--- a/tensorflow/tools/compatibility/testdata/test_file_v1_10.py
+++ b/tensorflow/tools/compatibility/testdata/test_file_v1_10.py
@@ -25,10 +25,47 @@
class TestUpgrade(test_util.TensorFlowTestCase):
"""Test various APIs that have been changed in 2.0."""
+ def setUp(self):
+ tf.enable_eager_execution()
+
def testRenames(self):
with self.cached_session():
- self.assertAllClose(1.04719755, tf.acos(0.5).eval())
- self.assertAllClose(0.5, tf.rsqrt(4.0).eval())
+ self.assertAllClose(1.04719755, tf.acos(0.5))
+ self.assertAllClose(0.5, tf.rsqrt(4.0))
+
+ def testSerializeSparseTensor(self):
+ sp_input = tf.SparseTensor(
+ indices=tf.constant([[1]], dtype=tf.int64),
+ values=tf.constant([2], dtype=tf.int64),
+ dense_shape=[2])
+
+ with self.cached_session():
+ serialized_sp = tf.serialize_sparse(sp_input, 'serialize_name', tf.string)
+ self.assertEqual((3,), serialized_sp.shape)
+ self.assertTrue(serialized_sp[0].numpy()) # check non-empty
+
+ def testSerializeManySparse(self):
+ sp_input = tf.SparseTensor(
+ indices=tf.constant([[0, 1]], dtype=tf.int64),
+ values=tf.constant([2], dtype=tf.int64),
+ dense_shape=[1, 2])
+
+ with self.cached_session():
+ serialized_sp = tf.serialize_many_sparse(
+ sp_input, 'serialize_name', tf.string)
+ self.assertEqual((1, 3), serialized_sp.shape)
+
+ def testArgMaxMin(self):
+ self.assertAllClose(
+ [1],
+ tf.argmax([[1, 3, 2]], name='abc', dimension=1))
+ self.assertAllClose(
+ [0, 0, 0],
+ tf.argmax([[1, 3, 2]], dimension=0))
+ self.assertAllClose(
+ [0],
+ tf.argmin([[1, 3, 2]], name='abc', dimension=1))
+
if __name__ == "__main__":
test_lib.main()
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py
index 9b2abb9..3cb78af 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py
@@ -31,9 +31,30 @@
# Maps from a function name to a dictionary that describes how to
# map from an old argument keyword to the new argument keyword.
self.function_keyword_renames = {
+ "tf.argmin": {
+ "dimension": "axis",
+ },
+ "tf.argmax": {
+ "dimension": "axis",
+ },
+ "tf.image.crop_and_resize": {
+ "box_ind": "box_indices",
+ },
+ "tf.image.extract_image_patches": {
+ "ksizes": "sizes",
+ },
+ "tf.extract_image_patches": {
+ "ksizes": "sizes",
+ },
"tf.expand_dims": {
"dim": "axis",
},
+ "tf.batch_to_space_nd": {
+ "block_size": "block_shape",
+ },
+ "tf.constant": {
+ "verify_shapes": "verify_shapes_is_now_always_true",
+ },
"tf.convert_to_tensor": {
"preferred_dtype": "dtype_hint"
},
@@ -51,13 +72,23 @@
"tf.nn.sufficient_statistics": {
"keep_dims": "keepdims"
},
+ "tf.nn.log_softmax": {
+ "dim": "axis",
+ },
+ "tf.nn.softmax": {
+ "dim": "axis",
+ },
"tf.debugging.assert_all_finite": {
"t": "x",
"msg": "message",
},
+ "tf.sparse.add": ["a", "b", "thresh"],
"tf.sparse.split": {
"split_dim": "axis",
},
+ "tf.max_pool_with_argmax": {
+ "Targmax": "output_dtype",
+ },
"tf.multinomial": {
"output_dtype": "dtype",
},
@@ -69,6 +100,13 @@
"m": "mean",
"v": "variance",
},
+ "tf.manip.batch_to_space_nd": {
+ "block_size": "block_shape",
+ },
+ "tf.nn.dilation2d": {
+ "filter": "filters",
+ "rates": "dilations",
+ },
"tf.nn.conv3d": {
"filter": "filters"
},
@@ -89,9 +127,131 @@
"tf.gfile.Exists": {
"filename": "path",
},
+ "tf.gfile.Remove": {
+ "filename": "path",
+ },
+ "tf.gfile.Stat": {
+ "filename": "path",
+ },
+ "tf.gfile.Glob": {
+ "filename": "pattern",
+ },
+ "tf.gfile.MkDir": {
+ "dirname": "path",
+ },
+ "tf.gfile.MakeDirs": {
+ "dirname": "path",
+ },
+ "tf.gfile.DeleteRecursively": {
+ "dirname": "path",
+ },
+ "tf.gfile.IsDirectory": {
+ "dirname": "path",
+ },
+ "tf.gfile.ListDirectory": {
+ "dirname": "path",
+ },
+ "tf.gfile.Copy": {
+ "oldpath": "src",
+ "newpath": "dst",
+ },
+ "tf.gfile.Rename": {
+ "oldpath": "src",
+ "newpath": "dst",
+ },
+ "tf.gfile.Walk": {
+ "in_order": "topdown",
+ },
"tf.random.stateless_multinomial": {
"output_dtype": "dtype",
},
+ "tf.linalg.l2_normalize": {
+ "dim": "axis",
+ },
+ "tf.math.l2_normalize": {
+ "dim": "axis",
+ },
+ "tf.nn.l2_normalize": {
+ "dim": "axis",
+ },
+ "tf.sparse.concat": [
+ "axis", "sp_inputs", "name", "expand_nonconcat_dim", "concat_dim"
+ ],
+ "tf.reduce_all": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.math.reduce_all": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.reduce_any": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.math.reduce_any": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.reduce_min": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.math.reduce_min": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.reduce_max": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.math.reduce_max": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.reduce_sum": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.math.reduce_sum": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.reduce_mean": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.math.reduce_mean": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.reduce_prod": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.math.reduce_prod": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.reduce_logsumexp": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.math.reduce_logsumexp": {
+ "reduction_indices": "axis",
+ "keep_dims": "keepdims",
+ },
+ "tf.reduce_join": {
+ "keep_dims": "keepdims",
+ "reduction_indices": "axis"
+ },
+ "tf.strings.reduce_join": {
+ "keep_dims": "keepdims",
+ "reduction_indices": "axis"
+ },
+ "tf.squeeze": {
+ "squeeze_dims": "axis",
+ },
}
# Mapping from function to the new name of the function
@@ -102,54 +262,142 @@
# function_reorders or function_keyword_renames, use the OLD function name.
# These renames happen after the arguments have been processed.
self.symbol_renames.update({
- "tf.contrib.data.AUTOTUNE": "tf.data.experimental.AUTOTUNE",
- "tf.contrib.data.Counter": "tf.data.experimental.Counter",
- "tf.contrib.data.CheckpointInputPipelineHook": "tf.data.experimental.CheckpointInputPipelineHook",
- "tf.contrib.data.CsvDataset": "tf.data.experimental.CsvDataset",
- "tf.contrib.data.Optional": "tf.data.experimental.Optional",
- "tf.contrib.data.RandomDataset": "tf.data.experimental.RandomDataset",
- "tf.contrib.data.Reducer": "tf.data.experimental.Reducer",
- "tf.contrib.data.SqlDataset": "tf.data.experimental.SqlDataset",
- "tf.contrib.data.StatsAggregator": "tf.data.experimental.StatsAggregator",
- "tf.contrib.data.TFRecordWriter": "tf.data.experimental.TFRecordWriter",
- "tf.contrib.data.assert_element_shape": "tf.data.experimental.assert_element_shape",
- "tf.contrib.data.batch_and_drop_remainder": "tf.compat.v1.contrib.data.batch_and_drop_remainder",
- "tf.contrib.data.bucket_by_sequence_length": "tf.data.experimental.bucket_by_sequence_length",
- "tf.contrib.data.choose_from_datasets": "tf.data.experimental.choose_from_datasets",
- "tf.contrib.data.copy_to_device": "tf.data.experimental.copy_to_device",
- "tf.contrib.data.dense_to_sparse_batch": "tf.data.experimental.dense_to_sparse_batch",
- "tf.contrib.data.enumerate_dataset": "tf.data.experimental.enumerate_dataset",
- "tf.contrib.data.get_next_as_optional": "tf.data.experimental.get_next_as_optional",
- "tf.contrib.data.get_single_element": "tf.data.experimental.get_single_element",
- "tf.contrib.data.group_by_reducer": "tf.data.experimental.group_by_reducer",
- "tf.contrib.data.group_by_window": "tf.data.experimental.group_by_window",
- "tf.contrib.data.ignore_errors": "tf.data.experimental.ignore_errors",
- "tf.contrib.data.latency_stats": "tf.data.experimental.latency_stats",
- "tf.contrib.data.make_batched_features_dataset": "tf.data.experimental.make_batched_features_dataset",
- "tf.contrib.data.make_csv_dataset": "tf.data.experimental.make_csv_dataset",
- "tf.contrib.data.make_saveable_from_iterator": "tf.data.experimental.make_saveable_from_iterator",
- "tf.contrib.data.map_and_batch": "tf.data.experimental.map_and_batch",
- "tf.contrib.data.padded_batch_and_drop_remainder": "tf.compat.v1.contrib.data.padded_batch_and_drop_remainder",
- "tf.contrib.data.parallel_interleave": "tf.data.experimental.parallel_interleave",
- "tf.contrib.data.parse_example_dataset": "tf.data.experimental.parse_example_dataset",
- "tf.contrib.data.prefetch_to_device": "tf.data.experimental.prefetch_to_device",
- "tf.contrib.data.read_batch_features": "tf.compat.v1.contrib.data.read_batch_features",
- "tf.contrib.data.reduce_dataset": "tf.compat.v1.contrib.data.reduce_dataset",
- "tf.contrib.data.rejection_resample": "tf.data.experimental.rejection_resample",
- "tf.contrib.data.sample_from_datasets": "tf.data.experimental.sample_from_datasets",
- "tf.contrib.data.scan": "tf.data.experimental.scan",
- "tf.contrib.data.set_stats_aggregator": "tf.data.experimental.set_stats_aggregator",
- "tf.contrib.data.shuffle_and_repeat": "tf.data.experimental.shuffle_and_repeat",
- "tf.contrib.data.sliding_window_batch": "tf.compat.v1.contrib.data.sliding_window_batch",
- "tf.contrib.data.sloppy_interleave": "tf.compat.v1.contrib.data.sloppy_interleave",
- "tf.contrib.data.unbatch": "tf.data.experimental.unbatch",
- "tf.contrib.data.unique": "tf.data.experimental.unique",
- "tf.quantize_v2": "tf.quantization.quantize",
- "tf.sparse_concat": "tf.sparse.concat",
- "tf.sparse_split": "tf.sparse.split",
- "tf.multinomial": "tf.random.categorical",
- "tf.random.multinomial": "tf.random.categorical",
- "tf.load_file_system_library": "tf.load_library",
+ "tf.batch_to_space_nd":
+ "tf.batch_to_space",
+ "tf.gfile.Copy":
+ "tf.io.gfile.Copy",
+ "tf.gfile.DeleteRecursively":
+ "tf.io.gfile.DeleteRecursively",
+ "tf.gfile.Exists":
+ "tf.io.gfile.Exists",
+ "tf.gfile.Glob":
+ "tf.io.gfile.Glob",
+ "tf.gfile.IsDirectory":
+ "tf.io.gfile.IsDirectory",
+ "tf.gfile.ListDirectory":
+ "tf.io.gfile.ListDirectory",
+ "tf.gfile.MakeDirs":
+ "tf.io.gfile.MakeDirs",
+ "tf.gfile.MkDir":
+ "tf.io.gfile.MkDir",
+ "tf.gfile.Remove":
+ "tf.io.gfile.Remove",
+ "tf.gfile.Rename":
+ "tf.io.gfile.Rename",
+ "tf.gfile.Stat":
+ "tf.io.gfile.Stat",
+ "tf.gfile.Walk":
+ "tf.io.gfile.Walk",
+ "tf.contrib.data.AUTOTUNE":
+ "tf.data.experimental.AUTOTUNE",
+ "tf.contrib.data.Counter":
+ "tf.data.experimental.Counter",
+ "tf.contrib.data.CheckpointInputPipelineHook":
+ "tf.data.experimental.CheckpointInputPipelineHook",
+ "tf.contrib.data.CsvDataset":
+ "tf.data.experimental.CsvDataset",
+ "tf.contrib.data.Optional":
+ "tf.data.experimental.Optional",
+ "tf.contrib.data.RandomDataset":
+ "tf.data.experimental.RandomDataset",
+ "tf.contrib.data.Reducer":
+ "tf.data.experimental.Reducer",
+ "tf.contrib.data.SqlDataset":
+ "tf.data.experimental.SqlDataset",
+ "tf.contrib.data.StatsAggregator":
+ "tf.data.experimental.StatsAggregator",
+ "tf.contrib.data.TFRecordWriter":
+ "tf.data.experimental.TFRecordWriter",
+ "tf.contrib.data.assert_element_shape":
+ "tf.data.experimental.assert_element_shape",
+ "tf.contrib.data.batch_and_drop_remainder":
+ "tf.compat.v1.contrib.data.batch_and_drop_remainder",
+ "tf.contrib.data.bucket_by_sequence_length":
+ "tf.data.experimental.bucket_by_sequence_length",
+ "tf.contrib.data.choose_from_datasets":
+ "tf.data.experimental.choose_from_datasets",
+ "tf.contrib.data.copy_to_device":
+ "tf.data.experimental.copy_to_device",
+ "tf.contrib.data.dense_to_sparse_batch":
+ "tf.data.experimental.dense_to_sparse_batch",
+ "tf.contrib.data.enumerate_dataset":
+ "tf.data.experimental.enumerate_dataset",
+ "tf.contrib.data.get_next_as_optional":
+ "tf.data.experimental.get_next_as_optional",
+ "tf.contrib.data.get_single_element":
+ "tf.data.experimental.get_single_element",
+ "tf.contrib.data.group_by_reducer":
+ "tf.data.experimental.group_by_reducer",
+ "tf.contrib.data.group_by_window":
+ "tf.data.experimental.group_by_window",
+ "tf.contrib.data.ignore_errors":
+ "tf.data.experimental.ignore_errors",
+ "tf.contrib.data.latency_stats":
+ "tf.data.experimental.latency_stats",
+ "tf.contrib.data.make_batched_features_dataset":
+ "tf.data.experimental.make_batched_features_dataset",
+ "tf.contrib.data.make_csv_dataset":
+ "tf.data.experimental.make_csv_dataset",
+ "tf.contrib.data.make_saveable_from_iterator":
+ "tf.data.experimental.make_saveable_from_iterator",
+ "tf.contrib.data.map_and_batch":
+ "tf.data.experimental.map_and_batch",
+ "tf.contrib.data.padded_batch_and_drop_remainder":
+ "tf.compat.v1.contrib.data.padded_batch_and_drop_remainder",
+ "tf.contrib.data.parallel_interleave":
+ "tf.data.experimental.parallel_interleave",
+ "tf.contrib.data.parse_example_dataset":
+ "tf.data.experimental.parse_example_dataset",
+ "tf.contrib.data.prefetch_to_device":
+ "tf.data.experimental.prefetch_to_device",
+ "tf.contrib.data.read_batch_features":
+ "tf.compat.v1.contrib.data.read_batch_features",
+ "tf.contrib.data.reduce_dataset":
+ "tf.compat.v1.contrib.data.reduce_dataset",
+ "tf.contrib.data.rejection_resample":
+ "tf.data.experimental.rejection_resample",
+ "tf.contrib.data.sample_from_datasets":
+ "tf.data.experimental.sample_from_datasets",
+ "tf.contrib.data.scan":
+ "tf.data.experimental.scan",
+ "tf.contrib.data.set_stats_aggregator":
+ "tf.data.experimental.set_stats_aggregator",
+ "tf.contrib.data.shuffle_and_repeat":
+ "tf.data.experimental.shuffle_and_repeat",
+ "tf.contrib.data.sliding_window_batch":
+ "tf.compat.v1.contrib.data.sliding_window_batch",
+ "tf.contrib.data.sloppy_interleave":
+ "tf.compat.v1.contrib.data.sloppy_interleave",
+ "tf.contrib.data.unbatch":
+ "tf.data.experimental.unbatch",
+ "tf.contrib.data.unique":
+ "tf.data.experimental.unique",
+ "tf.contrib.framework.sort":
+ "tf.sort",
+ "tf.contrib.framework.argsort":
+ "tf.argsort",
+ "tf.manip.batch_to_space_nd":
+ "tf.batch_to_space",
+ "tf.quantize_v2":
+ "tf.quantization.quantize",
+ "tf.sparse_concat":
+ "tf.sparse.concat",
+ "tf.sparse_split":
+ "tf.sparse.split",
+ "tf.multinomial":
+ "tf.random.categorical",
+ "tf.random.multinomial":
+ "tf.random.categorical",
+ "tf.load_file_system_library":
+ "tf.load_library",
+ "tf.pywrap_tensorflow":
+ "tf.compat.v1.pywrap_tensorflow",
+ "tf.bincount":
+ "tf.math.bincount",
+ "tf.confusion_matrix":
+ "tf.math.confusion_matrix",
+ "tf.train.confusion_matrix":
+ "tf.math.confusion_matrix",
})
# pylint: enable=line-too-long
@@ -167,13 +415,17 @@
# IMPORTANT: order here should correspond to OLD argument order.
# We just prepend "arg_name=" to all arguments in function calls.
self.function_reorders = {
- "tf.argmax": ["input", "axis", "name", "dimension", "output_type"],
- "tf.argmin": ["input", "axis", "name", "dimension", "output_type"],
+ "tf.io.serialize_sparse": ["sp_input", "name", "out_type"],
+ "tf.io.serialize_many_sparse": ["sp_input", "name", "out_type"],
+ "tf.argmax": ["input", "axis", "name", "axis", "output_type"],
+ "tf.argmin": ["input", "axis", "name", "axis", "output_type"],
+ "tf.batch_to_space": ["input", "crops", "block_size", "name"],
"tf.boolean_mask": ["tensor", "mask", "name", "axis"],
"tf.convert_to_tensor": ["value", "dtype", "name", "preferred_dtype"],
"tf.nn.convolution": [
"input", "filter", "padding", "strides", "dilation_rate", "name",
- "data_format"],
+ "data_format"
+ ],
"tf.nn.crelu": ["features", "name", "axis"],
"tf.nn.pool": [
"input", "window_shape", "pooling_type", "padding", "dilation_rate",
@@ -183,6 +435,7 @@
"input", "filter", "strides", "padding", "rate", "name",
"data_format"
],
+ "tf.manip.batch_to_space_nd": ["input", "crops", "block_size", "name"],
"tf.multinomial": [
"logits", "num_samples", "seed", "name", "output_dtype"
],
@@ -191,15 +444,19 @@
],
"tf.pad": ["tensor", "paddings", "mode", "name", "constant_values"],
"tf.quantize_v2": [
- "input", "min_range", "max_range", "T", "mode", "name",
- "round_mode"
+ "input", "min_range", "max_range", "T", "mode", "name", "round_mode"
+ ],
+ "tf.feature_column.categorical_column_with_vocabulary_file": [
+ "key", "vocabulary_file", "vocabulary_size", "num_oov_buckets",
+ "default_value", "dtype"
],
"tf.shape": ["input", "name", "out_type"],
"tf.size": ["input", "name", "out_type"],
+ "tf.random.poisson": ["lam", "shape", "dtype", "seed", "name"],
+ "tf.sparse.add": ["a", "b", "thresh"],
"tf.sparse.concat": [
"axis", "sp_inputs", "name", "expand_nonconcat_dim", "concat_dim"
],
- "tf.random.poisson": ["lam", "shape", "dtype", "seed", "name"],
"tf.sparse.segment_mean": [
"data", "indices", "segment_ids", "name", "num_segments"
],
@@ -209,11 +466,118 @@
"tf.sparse.segment_sum": [
"data", "indices", "segment_ids", "name", "num_segments"
],
+ "tf.io.decode_csv": [
+ "records",
+ "record_defaults",
+ "field_delim",
+ "use_quote_delim",
+ "name",
+ "na_value",
+ "select_cols",
+ ],
+ "tf.strings.substr": ["input", "pos", "len", "name", "unit"],
+ "tf.strings.reduce_join": [
+ "input", "axis", "keep_dims", "separator", "name",
+ "reduction_indices"
+ ],
"tf.strings.length": ["input", "name", "unit"],
+ "tf.transpose": ["a", "perm", "name", "conjugate"],
+ "tf.tuple": ["tensors", "name", "control_inputs"],
+ "tf.io.parse_example": [
+ "serialized", "features", "name", "example_names"
+ ],
+ "tf.io.parse_single_example": [
+ "serialized", "features", "name", "example_names"
+ ],
+ "tf.while_loop": [
+ "cond", "body", "loop_vars", "shape_invariants",
+ "parallel_iterations", "back_prop", "swap_memory", "name",
+ "maximum_iterations", "return_same_structure"
+ ],
+ "tf.reduce_all": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.math.reduce_all": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.reduce_any": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.math.reduce_any": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.reduce_min": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.math.reduce_min": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.reduce_max": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.math.reduce_max": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.reduce_sum": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.math.reduce_sum": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.reduce_mean": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.math.reduce_mean": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.reduce_prod": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.math.reduce_prod": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.reduce_logsumexp": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.math.reduce_logsumexp": [
+ "input_tensor", "axis", "keepdims", "name", "reduction_indices",
+ "keep_dims"
+ ],
+ "tf.reduce_join": [
+ "input", "axis", "keep_dims", "separator", "name",
+ "reduction_indices"
+ ],
+ "tf.confusion_matrix": [
+ "labels", "predictions", "num_classes", "dtype", "name", "weights"
+ ],
+ "tf.math.confusion_matrix": [
+ "labels", "predictions", "num_classes", "dtype", "name", "weights"
+ ]
}
# Specially handled functions.
- self.function_handle = {}
+ self.function_handle = {
+ "tf.nn.dropout": self._dropout_handler,
+ "tf.gradients": self._colocate_handler("tf.gradients"),
+ "*.minimize": self._colocate_handler("Optimizer.minimize"),
+ "*.compute_gradients":
+ self._colocate_handler("Optimizer.compute_gradients"),
+ }
decay_function_comment = (
"ERROR: <function name> has been changed to return a callable instead "
@@ -322,6 +686,65 @@
if name not in self.function_warnings and name not in excluded_renames
}
+ export_saved_model_renamed = (
+ "(Manual edit required) Please rename the function export_savedmodel() "
+ "to export_saved_model(). Two things to note:\n\t(1) The argument "
+ "strip_default_attributes has been removed. The function will always "
+ "strip the default attributes from ops. If this breaks your code, "
+ "please switch to tf.compat.v1.estimator.Estimator.\n\t(2) This change "
+ "only effects core estimator. If you are using "
+ "tf.contrib.learn.Estimator, please switch to using core estimator.")
+
+ # Specify warnings for functions that aren't restricted to the tf.x.y.z
+ # format. This should only be used for methods with unique names, e.g.
+ # export_savedmodel, which is only defined in Estimator objects.
+ self.unrestricted_function_warnings = {
+ "export_savedmodel": export_saved_model_renamed,
+ }
+
+ @staticmethod
+ def _dropout_handler(file_edit_recorder, node):
+ if len(node.args) < 2:
+ comment = ("ERROR: tf.nn.dropout did not take arguments, so automatic "
+ "transformation was disabled. tf.nn.dropout has changed "
+ "the semantics of the second argument.")
+ file_edit_recorder.add(
+ comment,
+ node.lineno,
+ node.col_offset,
+ "tf.nn.dropout",
+ "tf.nn.dropout",
+ error="tf.nn.dropout requires manual check.")
+ else:
+ comment = ("WARNING: tf.nn.dropout has changed the semantics of the "
+ "second argument. Please check the transformation.\n")
+ file_edit_recorder.add(
+ comment,
+ node.args[1].lineno,
+ node.args[1].col_offset,
+ "",
+ "1 - ")
+
+ @staticmethod
+ def _colocate_handler(name):
+ def _helper(file_edit_recorder, node):
+ for keyword in node.keywords:
+ if keyword.arg == "colocate_gradients_with_ops":
+ # TODO(jhseu): Since ast_edit.py does string replacement, there's no
+ # straightforward way to remove the argument. Try to fix before 2.0 is
+ # final.
+ comment = ("For tf.gradients and tf.Optimizer.minimize, "
+ "colocate_gradients_with_op has been removed and now "
+ "defaults to True.")
+ file_edit_recorder.add(
+ comment,
+ node.lineno,
+ node.col_offset,
+ "",
+ "",
+ error="{} requires manual check.".format(name))
+ return _helper
+
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -329,8 +752,8 @@
description="""Convert a TensorFlow Python file to 2.0
Simple usage:
- tf_convert_v2.py --infile foo.py --outfile bar.py
- tf_convert_v2.py --intree ~/code/old --outtree ~/code/new
+ tf_upgrade_v2.py --infile foo.py --outfile bar.py
+ tf_upgrade_v2.py --intree ~/code/old --outtree ~/code/new
""")
parser.add_argument(
"--infile",
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
index 7baa1ca..0414bec 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
@@ -116,6 +116,22 @@
self.assertEqual(errors, ["test.py:1: %s requires manual check." % ns])
self.assertIn("loss_reduction has been changed", report)
+ def testDropout(self):
+ text = "tf.nn.dropout(x, keep_prob, name=\"foo\")\n"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(
+ new_text,
+ "tf.nn.dropout(x, 1 - keep_prob, name=\"foo\")\n",
+ )
+
+ text = "tf.nn.dropout(x)\n"
+ _, unused_report, errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, text)
+ self.assertEqual(
+ errors,
+ ["test.py:1: tf.nn.dropout requires manual check."]
+ )
+
def testCountNonZeroChanges(self):
text = (
"tf.math.count_nonzero(input_tensor=input, dtype=dtype, name=name, "
@@ -162,6 +178,69 @@
)
self.assertEqual(new_text, expected_text)
+ def testColocateGradientsWithOps(self):
+ text = "tf.gradients(a, foo=False)\n"
+ _, unused_report, errors, new_text = self._upgrade(text)
+ self.assertEqual(text, new_text)
+ self.assertEqual(errors, [])
+
+ text = "tf.gradients(a, colocate_gradients_with_ops=False)\n"
+ _, unused_report, errors, new_text = self._upgrade(text)
+ self.assertEqual(text, new_text)
+ self.assertEqual(errors, ["test.py:1: tf.gradients requires manual check."])
+
+ text = "optimizer.minimize(a, foo=False)\n"
+ _, unused_report, errors, new_text = self._upgrade(text)
+ self.assertEqual(text, new_text)
+ self.assertEqual(errors, [])
+
+ text = "optimizer.minimize(a, colocate_gradients_with_ops=False)\n"
+ _, unused_report, errors, new_text = self._upgrade(text)
+ self.assertEqual(text, new_text)
+ self.assertEqual(errors,
+ ["test.py:1: Optimizer.minimize requires manual check."])
+
+ text = "optimizer.compute_gradients(a, foo=False)\n"
+ _, unused_report, errors, new_text = self._upgrade(text)
+ self.assertEqual(text, new_text)
+ self.assertEqual(errors, [])
+
+ text = "optimizer.compute_gradients(a, colocate_gradients_with_ops=False)\n"
+ _, unused_report, errors, new_text = self._upgrade(text)
+ self.assertEqual(text, new_text)
+ self.assertEqual(errors,
+ ["test.py:1: Optimizer.compute_gradients "
+ "requires manual check."])
+
+ def testExportSavedModelRename(self):
+ text = "self.est.export_savedmodel(path)"
+ _, report, unused_errors, unused_new_text = self._upgrade(text)
+ self.assertIn(
+ "rename the function export_savedmodel() to export_saved_model()",
+ report)
+
+ def testArgmin(self):
+ text = "tf.argmin(input, name=n, dimension=1, output_type=type)"
+ expected_text = "tf.argmin(input=input, name=n, axis=1, output_type=type)"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, expected_text)
+
+ text = "tf.argmin(input, 0)"
+ expected_text = "tf.argmin(input=input, axis=0)"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, expected_text)
+
+ def testArgmax(self):
+ text = "tf.argmax(input, name=n, dimension=1, output_type=type)"
+ expected_text = "tf.argmax(input=input, name=n, axis=1, output_type=type)"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, expected_text)
+
+ text = "tf.argmax(input, 0)"
+ expected_text = "tf.argmax(input=input, axis=0)"
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, expected_text)
+
class TestUpgradeFiles(test_util.TensorFlowTestCase):
diff --git a/tensorflow/tools/docker/Dockerfile b/tensorflow/tools/docker/Dockerfile
index 205128a..6676de0 100644
--- a/tensorflow/tools/docker/Dockerfile
+++ b/tensorflow/tools/docker/Dockerfile
@@ -1,4 +1,4 @@
-FROM ubuntu:16.04
+FROM ubuntu:18.04
LABEL maintainer="Craig Citro <craigcitro@google.com>"
@@ -8,7 +8,7 @@
curl \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
+ libpng-dev \
libzmq3-dev \
pkg-config \
python \
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index a3893a2..c256dd3 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -1,4 +1,4 @@
-FROM ubuntu:16.04
+FROM ubuntu:18.04
LABEL maintainer="Craig Citro <craigcitro@google.com>"
@@ -9,7 +9,7 @@
libcurl3-dev \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
+ libpng-dev \
libzmq3-dev \
pkg-config \
python-dev \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl
index bd2883d..2341c0e 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl
@@ -1,4 +1,4 @@
-FROM ubuntu:16.04
+FROM ubuntu:18.04
LABEL maintainer="Clayne Robison <clayne.b.robison@intel.com>"
@@ -16,7 +16,7 @@
libcurl3-dev \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
+ libpng-dev \
libzmq3-dev \
libssl-dev \
pkg-config \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
index df084e0..5e24617 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
@@ -1,4 +1,4 @@
-FROM ubuntu:16.04
+FROM ubuntu:18.04
LABEL maintainer="Cong Xu <cong.xu@intel.com>"
@@ -16,7 +16,7 @@
libcurl3-dev \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
+ libpng-dev \
libzmq3-dev \
pkg-config \
python-dev \
diff --git a/tensorflow/tools/docker/Dockerfile.mkl b/tensorflow/tools/docker/Dockerfile.mkl
index ac41cff..dad2769 100755
--- a/tensorflow/tools/docker/Dockerfile.mkl
+++ b/tensorflow/tools/docker/Dockerfile.mkl
@@ -1,4 +1,4 @@
-FROM ubuntu:16.04
+FROM ubuntu:18.04
LABEL maintainer="Clayne Robison <clayne.b.robison@intel.com>"
@@ -17,7 +17,7 @@
curl \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
+ libpng-dev \
libzmq3-dev \
pkg-config \
${PYTHON} \
diff --git a/tensorflow/tools/docker/Dockerfile.mkl-horovod b/tensorflow/tools/docker/Dockerfile.mkl-horovod
index 0432cd5..19dc45c 100755
--- a/tensorflow/tools/docker/Dockerfile.mkl-horovod
+++ b/tensorflow/tools/docker/Dockerfile.mkl-horovod
@@ -1,4 +1,4 @@
-FROM ubuntu:16.04
+FROM ubuntu:18.04
LABEL maintainer="Cong Xu <cong.xu@intel.com>"
@@ -17,7 +17,7 @@
curl \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
+ libpng-dev \
libzmq3-dev \
pkg-config \
python \
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index e164853..34c600a 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -107,6 +107,7 @@
# TensorBoard command, pip will inappropriately remove it during install,
# even though the command is not removed, just moved to a different wheel.
'tensorboard = tensorboard.main:run_main',
+ 'tf_upgrade_v2 = tensorflow.tools.compatibility.tf_upgrade_v2:main',
]
# pylint: enable=line-too-long
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 14a5d0a..d9d4087 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -77,31 +77,31 @@
mkl_repository(
name = "mkl_linux",
build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
- sha256 = "e2233534a9d15c387e22260997af4312a39e9f86f791768409be273b5453c4e6",
- strip_prefix = "mklml_lnx_2019.0.20180710",
+ sha256 = "f00dc3b142a5be399bdeebd7e7ea369545a35d4fb84c86f98b6b048d72685295",
+ strip_prefix = "mklml_lnx_2019.0.1.20180928",
urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_lnx_2019.0.20180710.tgz",
- "https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_lnx_2019.0.20180710.tgz",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.17-rc/mklml_lnx_2019.0.1.20180928.tgz",
+ "https://github.com/intel/mkl-dnn/releases/download/v0.17-rc/mklml_lnx_2019.0.1.20180928.tgz",
],
)
mkl_repository(
name = "mkl_windows",
build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
- sha256 = "3fdcff17b018a0082491adf3ba143358265336a801646e46e0191ec8d58d24a2",
- strip_prefix = "mklml_win_2019.0.20180710",
+ sha256 = "efef90b7b9613fab10f44c8ac4ff28db613a112c64ed94826d7e44df09c44b0b",
+ strip_prefix = "mklml_win_2019.0.1.20180928",
urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_win_2019.0.20180710.zip",
- "https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_win_2019.0.20180710.zip",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.17-rc/mklml_win_2019.0.1.20180928.zip",
+ "https://github.com/intel/mkl-dnn/releases/download/v0.17-rc/mklml_win_2019.0.1.20180928.zip",
],
)
mkl_repository(
name = "mkl_darwin",
build_file = clean_dep("//third_party/mkl:mkl.BUILD"),
- sha256 = "411a30014a938eb83fb9f37b3dbe8e371b106fc1dd621fc23123cadc72737ce6",
- strip_prefix = "mklml_mac_2019.0.20180710",
+ sha256 = "83f02938a0c095274db7b8b7b694157abafa3837c5cbaef740440d466c86a477",
+ strip_prefix = "mklml_mac_2019.0.1.20180928",
urls = [
- "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.16/mklml_mac_2019.0.20180710.tgz",
- "https://github.com/intel/mkl-dnn/releases/download/v0.16/mklml_mac_2019.0.20180710.tgz",
+ "https://mirror.bazel.build/github.com/intel/mkl-dnn/releases/download/v0.17-rc/mklml_mac_2019.0.1.20180928.tgz",
+ "https://github.com/intel/mkl-dnn/releases/download/v0.17-rc/mklml_mac_2019.0.1.20180928.tgz",
],
)
@@ -179,15 +179,15 @@
tf_http_archive(
name = "com_github_googlecloudplatform_google_cloud_cpp",
- sha256 = "fdd3b3aecce60987e5525e55bf3a21d68a8695320bd5b980775af6507eec3944",
- strip_prefix = "google-cloud-cpp-14760a86c4ffab9943b476305c4fe927ad95db1c",
+ sha256 = "3ade2072e6588ff56c0434abe6c63aa5f3f2d56be15a299bafc7e9cdf0a12c17",
+ strip_prefix = "google-cloud-cpp-0.3.0",
system_build_file = clean_dep("//third_party/systemlibs:google_cloud_cpp.BUILD"),
system_link_files = {
"//third_party/systemlibs:google_cloud_cpp.google.cloud.bigtable.BUILD": "google/cloud/bigtable/BUILD",
},
urls = [
- "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz",
- "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/14760a86c4ffab9943b476305c4fe927ad95db1c.tar.gz",
+ "https://mirror.bazel.build/github.com/GoogleCloudPlatform/google-cloud-cpp/archive/v0.3.0.tar.gz",
+ "https://github.com/GoogleCloudPlatform/google-cloud-cpp/archive/v0.3.0.tar.gz",
],
)
@@ -472,11 +472,11 @@
tf_http_archive(
name = "llvm",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
- sha256 = "286465fc41ade5c1c44e4a6dce9681106664fcdd12264dc9be63fc22bbee3c9c",
- strip_prefix = "llvm-0478924a3727c74fd482d07eed45a8347540576e",
+ sha256 = "7b4f705c532ee2aafb6e8b9013ad22ec8bb1823a153cd2d6ddb6b7faef818874",
+ strip_prefix = "llvm-9ad322c7dfd4385be9a515d734f70700f192ebae",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/0478924a3727c74fd482d07eed45a8347540576e.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/0478924a3727c74fd482d07eed45a8347540576e.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/9ad322c7dfd4385be9a515d734f70700f192ebae.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/9ad322c7dfd4385be9a515d734f70700f192ebae.tar.gz",
],
)
@@ -689,11 +689,11 @@
tf_http_archive(
name = "arm_neon_2_x86_sse",
build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"),
- sha256 = "c8d90aa4357f8079d427e87a6f4c493da1fa4140aee926c05902d7ec1533d9a5",
- strip_prefix = "ARM_NEON_2_x86_SSE-0f77d9d182265259b135dad949230ecbf1a2633d",
+ sha256 = "213733991310b904b11b053ac224fee2d4e0179e46b52fe7f8735b8831e04dcc",
+ strip_prefix = "ARM_NEON_2_x86_SSE-1200fe90bb174a6224a525ee60148671a786a71f",
urls = [
- "https://mirror.bazel.build/github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz",
- "https://github.com/intel/ARM_NEON_2_x86_SSE/archive/0f77d9d182265259b135dad949230ecbf1a2633d.tar.gz",
+ "https://mirror.bazel.build/github.com/intel/ARM_NEON_2_x86_SSE/archive/1200fe90bb174a6224a525ee60148671a786a71f.tar.gz",
+ "https://github.com/intel/ARM_NEON_2_x86_SSE/archive/1200fe90bb174a6224a525ee60148671a786a71f.tar.gz",
],
)
diff --git a/third_party/libxsmm.BUILD b/third_party/libxsmm.BUILD
index ee49d28..dc7dcc9 100644
--- a/third_party/libxsmm.BUILD
+++ b/third_party/libxsmm.BUILD
@@ -38,8 +38,8 @@
":libxsmm_interface",
],
visibility = [
- "//third_party/eigen3:__pkg__",
"//tensorflow/core/kernels:__pkg__",
+ "//third_party/eigen3:__pkg__",
],
)
diff --git a/third_party/toolchains/BUILD b/third_party/toolchains/BUILD
index a7b4687..9da417f 100644
--- a/third_party/toolchains/BUILD
+++ b/third_party/toolchains/BUILD
@@ -35,3 +35,16 @@
value:"docker://gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04@%s"
}""" % container_digests["cuda9.0-cudnn7-ubuntu14.04"],
)
+
+platform(
+ name = "rbe_cuda10.0-cudnn7-ubuntu14.04",
+ constraint_values = [
+ "@bazel_tools//platforms:x86_64",
+ "@bazel_tools//platforms:linux",
+ ],
+ remote_execution_properties = """
+ properties: {
+ name: "container-image"
+ value:"docker://gcr.io/asci-toolchain/nosla-cuda10.0-cudnn7-ubuntu14.04@%s"
+ }""" % container_digests["cuda10.0-cudnn7-ubuntu14.04"],
+)
diff --git a/third_party/toolchains/preconfig/generate/containers.bzl b/third_party/toolchains/preconfig/generate/containers.bzl
index 1f9e29d..0309b8f 100644
--- a/third_party/toolchains/preconfig/generate/containers.bzl
+++ b/third_party/toolchains/preconfig/generate/containers.bzl
@@ -1,4 +1,4 @@
container_digests = {
"cuda9.0-cudnn7-ubuntu14.04": "sha256:c26138f4c38c754da2bad44a8a068523abf7fbd71d58a57ce92e5342c5431bf5",
- "cuda10.0-cudnn7-ubuntu14.04": "sha256:34c4a55e2376b300cdc2b903775fc32e62352f6e33f927df5653743324378bfc",
+ "cuda10.0-cudnn7-ubuntu14.04": "sha256:7737d770599de8435115bfdf56977002319316a6735ab081f82506cb51443f9d",
}
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/WORKSPACE b/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/WORKSPACE
new file mode 100644
index 0000000..b61f572
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/WORKSPACE
@@ -0,0 +1,2 @@
+# DO NOT EDIT: automatically generated WORKSPACE file for cuda_configure rule
+workspace(name = "local_config_cuda")
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD b/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
new file mode 100755
index 0000000..c813efc
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
@@ -0,0 +1,1275 @@
+licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like
+
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+ name = "using_nvcc",
+ values = {
+ "define": "using_cuda_nvcc=true",
+ },
+)
+
+config_setting(
+ name = "using_clang",
+ values = {
+ "define": "using_cuda_clang=true",
+ },
+)
+
+# Equivalent to using_clang && -c opt.
+config_setting(
+ name = "using_clang_opt",
+ values = {
+ "define": "using_cuda_clang=true",
+ "compilation_mode": "opt",
+ },
+)
+
+config_setting(
+ name = "darwin",
+ values = {"cpu": "darwin"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "freebsd",
+ values = {"cpu": "freebsd"},
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cuda_headers",
+ hdrs = [
+ "cuda/cuda_config.h",
+ ":cuda-include",
+ ":cudnn-include",
+ ],
+ includes = [
+ ".",
+ "cuda/include",
+ "cuda/include/crt",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cudart_static",
+ srcs = ["cuda/lib/libcudart_static.a"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkopts = select({
+ ":freebsd": [],
+ "//conditions:default": ["-ldl"],
+ }) + [
+ "-lpthread",
+ "-lrt",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cuda_driver",
+ srcs = ["cuda/lib/libcuda.so"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cudart",
+ srcs = ["cuda/lib/libcudart.so.10.0"],
+ data = ["cuda/lib/libcudart.so.10.0"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cublas",
+ srcs = ["cuda/lib/libcublas.so.10.0"],
+ data = ["cuda/lib/libcublas.so.10.0"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cusolver",
+ srcs = ["cuda/lib/libcusolver.so.10.0"],
+ data = ["cuda/lib/libcusolver.so.10.0"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkopts = ["-lgomp"],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cudnn",
+ srcs = ["cuda/lib/libcudnn.so.7"],
+ data = ["cuda/lib/libcudnn.so.7"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cudnn_header",
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cufft",
+ srcs = ["cuda/lib/libcufft.so.10.0"],
+ data = ["cuda/lib/libcufft.so.10.0"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "curand",
+ srcs = ["cuda/lib/libcurand.so.10.0"],
+ data = ["cuda/lib/libcurand.so.10.0"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cuda",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cublas",
+ ":cuda_headers",
+ ":cudart",
+ ":cudnn",
+ ":cufft",
+ ":curand",
+ ],
+)
+
+cc_library(
+ name = "cupti_headers",
+ hdrs = [
+ "cuda/cuda_config.h",
+ ":cuda-extras",
+ ],
+ includes = [
+ ".",
+ "cuda/extras/CUPTI/include/",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cupti_dsos",
+ data = ["cuda/lib/libcupti.so.10.0"],
+ includes = [
+ ".",
+ "cuda/include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "libdevice_root",
+ data = [":cuda-nvvm"],
+ visibility = ["//visibility:public"],
+)
+
+genrule(
+ name = "cuda-include",
+ outs = [
+ "cuda/include/CL/cl.h",
+ "cuda/include/CL/cl.hpp",
+ "cuda/include/CL/cl_egl.h",
+ "cuda/include/CL/cl_ext.h",
+ "cuda/include/CL/cl_gl.h",
+ "cuda/include/CL/cl_gl_ext.h",
+ "cuda/include/CL/cl_platform.h",
+ "cuda/include/CL/opencl.h",
+ "cuda/include/builtin_types.h",
+ "cuda/include/channel_descriptor.h",
+ "cuda/include/common_functions.h",
+ "cuda/include/cooperative_groups.h",
+ "cuda/include/cooperative_groups_helpers.h",
+ "cuda/include/crt/common_functions.h",
+ "cuda/include/crt/device_double_functions.h",
+ "cuda/include/crt/device_double_functions.hpp",
+ "cuda/include/crt/device_functions.h",
+ "cuda/include/crt/device_functions.hpp",
+ "cuda/include/crt/func_macro.h",
+ "cuda/include/crt/host_config.h",
+ "cuda/include/crt/host_defines.h",
+ "cuda/include/crt/host_runtime.h",
+ "cuda/include/crt/math_functions.h",
+ "cuda/include/crt/math_functions.hpp",
+ "cuda/include/crt/mma.h",
+ "cuda/include/crt/mma.hpp",
+ "cuda/include/crt/nvfunctional",
+ "cuda/include/crt/sm_70_rt.h",
+ "cuda/include/crt/sm_70_rt.hpp",
+ "cuda/include/crt/storage_class.h",
+ "cuda/include/cuComplex.h",
+ "cuda/include/cublas.h",
+ "cuda/include/cublasXt.h",
+ "cuda/include/cublas_api.h",
+ "cuda/include/cublas_v2.h",
+ "cuda/include/cuda.h",
+ "cuda/include/cudaEGL.h",
+ "cuda/include/cudaGL.h",
+ "cuda/include/cudaProfiler.h",
+ "cuda/include/cudaVDPAU.h",
+ "cuda/include/cuda_device_runtime_api.h",
+ "cuda/include/cuda_egl_interop.h",
+ "cuda/include/cuda_fp16.h",
+ "cuda/include/cuda_fp16.hpp",
+ "cuda/include/cuda_gl_interop.h",
+ "cuda/include/cuda_occupancy.h",
+ "cuda/include/cuda_profiler_api.h",
+ "cuda/include/cuda_runtime.h",
+ "cuda/include/cuda_runtime_api.h",
+ "cuda/include/cuda_surface_types.h",
+ "cuda/include/cuda_texture_types.h",
+ "cuda/include/cuda_vdpau_interop.h",
+ "cuda/include/cudalibxt.h",
+ "cuda/include/cudart_platform.h",
+ "cuda/include/cufft.h",
+ "cuda/include/cufftXt.h",
+ "cuda/include/cufftw.h",
+ "cuda/include/curand.h",
+ "cuda/include/curand_discrete.h",
+ "cuda/include/curand_discrete2.h",
+ "cuda/include/curand_globals.h",
+ "cuda/include/curand_kernel.h",
+ "cuda/include/curand_lognormal.h",
+ "cuda/include/curand_mrg32k3a.h",
+ "cuda/include/curand_mtgp32.h",
+ "cuda/include/curand_mtgp32_host.h",
+ "cuda/include/curand_mtgp32_kernel.h",
+ "cuda/include/curand_mtgp32dc_p_11213.h",
+ "cuda/include/curand_normal.h",
+ "cuda/include/curand_normal_static.h",
+ "cuda/include/curand_philox4x32_x.h",
+ "cuda/include/curand_poisson.h",
+ "cuda/include/curand_precalc.h",
+ "cuda/include/curand_uniform.h",
+ "cuda/include/cusolverDn.h",
+ "cuda/include/cusolverRf.h",
+ "cuda/include/cusolverSp.h",
+ "cuda/include/cusolverSp_LOWLEVEL_PREVIEW.h",
+ "cuda/include/cusolver_common.h",
+ "cuda/include/cusparse.h",
+ "cuda/include/cusparse_v2.h",
+ "cuda/include/device_atomic_functions.h",
+ "cuda/include/device_atomic_functions.hpp",
+ "cuda/include/device_double_functions.h",
+ "cuda/include/device_functions.h",
+ "cuda/include/device_launch_parameters.h",
+ "cuda/include/device_types.h",
+ "cuda/include/driver_functions.h",
+ "cuda/include/driver_types.h",
+ "cuda/include/fatBinaryCtl.h",
+ "cuda/include/fatbinary.h",
+ "cuda/include/host_config.h",
+ "cuda/include/host_defines.h",
+ "cuda/include/library_types.h",
+ "cuda/include/math_constants.h",
+ "cuda/include/math_functions.h",
+ "cuda/include/mma.h",
+ "cuda/include/npp.h",
+ "cuda/include/nppcore.h",
+ "cuda/include/nppdefs.h",
+ "cuda/include/nppi.h",
+ "cuda/include/nppi_arithmetic_and_logical_operations.h",
+ "cuda/include/nppi_color_conversion.h",
+ "cuda/include/nppi_compression_functions.h",
+ "cuda/include/nppi_computer_vision.h",
+ "cuda/include/nppi_data_exchange_and_initialization.h",
+ "cuda/include/nppi_filtering_functions.h",
+ "cuda/include/nppi_geometry_transforms.h",
+ "cuda/include/nppi_linear_transforms.h",
+ "cuda/include/nppi_morphological_operations.h",
+ "cuda/include/nppi_statistics_functions.h",
+ "cuda/include/nppi_support_functions.h",
+ "cuda/include/nppi_threshold_and_compare_operations.h",
+ "cuda/include/npps.h",
+ "cuda/include/npps_arithmetic_and_logical_operations.h",
+ "cuda/include/npps_conversion_functions.h",
+ "cuda/include/npps_filtering_functions.h",
+ "cuda/include/npps_initialization.h",
+ "cuda/include/npps_statistics_functions.h",
+ "cuda/include/npps_support_functions.h",
+ "cuda/include/nppversion.h",
+ "cuda/include/nvToolsExt.h",
+ "cuda/include/nvToolsExtCuda.h",
+ "cuda/include/nvToolsExtCudaRt.h",
+ "cuda/include/nvToolsExtMeta.h",
+ "cuda/include/nvToolsExtSync.h",
+ "cuda/include/nvblas.h",
+ "cuda/include/nvfunctional",
+ "cuda/include/nvgraph.h",
+ "cuda/include/nvjpeg.h",
+ "cuda/include/nvml.h",
+ "cuda/include/nvrtc.h",
+ "cuda/include/nvtx3/nvToolsExt.h",
+ "cuda/include/nvtx3/nvToolsExtCuda.h",
+ "cuda/include/nvtx3/nvToolsExtCudaRt.h",
+ "cuda/include/nvtx3/nvToolsExtOpenCL.h",
+ "cuda/include/nvtx3/nvToolsExtSync.h",
+ "cuda/include/nvtx3/nvtxDetail/nvtxImpl.h",
+ "cuda/include/nvtx3/nvtxDetail/nvtxImplCore.h",
+ "cuda/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h",
+ "cuda/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h",
+ "cuda/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h",
+ "cuda/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h",
+ "cuda/include/nvtx3/nvtxDetail/nvtxInit.h",
+ "cuda/include/nvtx3/nvtxDetail/nvtxInitDecls.h",
+ "cuda/include/nvtx3/nvtxDetail/nvtxInitDefs.h",
+ "cuda/include/nvtx3/nvtxDetail/nvtxLinkOnce.h",
+ "cuda/include/nvtx3/nvtxDetail/nvtxTypes.h",
+ "cuda/include/sm_20_atomic_functions.h",
+ "cuda/include/sm_20_atomic_functions.hpp",
+ "cuda/include/sm_20_intrinsics.h",
+ "cuda/include/sm_20_intrinsics.hpp",
+ "cuda/include/sm_30_intrinsics.h",
+ "cuda/include/sm_30_intrinsics.hpp",
+ "cuda/include/sm_32_atomic_functions.h",
+ "cuda/include/sm_32_atomic_functions.hpp",
+ "cuda/include/sm_32_intrinsics.h",
+ "cuda/include/sm_32_intrinsics.hpp",
+ "cuda/include/sm_35_atomic_functions.h",
+ "cuda/include/sm_35_intrinsics.h",
+ "cuda/include/sm_60_atomic_functions.h",
+ "cuda/include/sm_60_atomic_functions.hpp",
+ "cuda/include/sm_61_intrinsics.h",
+ "cuda/include/sm_61_intrinsics.hpp",
+ "cuda/include/sobol_direction_vectors.h",
+ "cuda/include/surface_functions.h",
+ "cuda/include/surface_functions.hpp",
+ "cuda/include/surface_indirect_functions.h",
+ "cuda/include/surface_indirect_functions.hpp",
+ "cuda/include/surface_types.h",
+ "cuda/include/texture_fetch_functions.h",
+ "cuda/include/texture_fetch_functions.hpp",
+ "cuda/include/texture_indirect_functions.h",
+ "cuda/include/texture_indirect_functions.hpp",
+ "cuda/include/texture_types.h",
+ "cuda/include/thrust/adjacent_difference.h",
+ "cuda/include/thrust/advance.h",
+ "cuda/include/thrust/binary_search.h",
+ "cuda/include/thrust/complex.h",
+ "cuda/include/thrust/copy.h",
+ "cuda/include/thrust/count.h",
+ "cuda/include/thrust/detail/adjacent_difference.inl",
+ "cuda/include/thrust/detail/advance.inl",
+ "cuda/include/thrust/detail/alignment.h",
+ "cuda/include/thrust/detail/allocator/allocator_traits.h",
+ "cuda/include/thrust/detail/allocator/allocator_traits.inl",
+ "cuda/include/thrust/detail/allocator/copy_construct_range.h",
+ "cuda/include/thrust/detail/allocator/copy_construct_range.inl",
+ "cuda/include/thrust/detail/allocator/default_construct_range.h",
+ "cuda/include/thrust/detail/allocator/default_construct_range.inl",
+ "cuda/include/thrust/detail/allocator/destroy_range.h",
+ "cuda/include/thrust/detail/allocator/destroy_range.inl",
+ "cuda/include/thrust/detail/allocator/fill_construct_range.h",
+ "cuda/include/thrust/detail/allocator/fill_construct_range.inl",
+ "cuda/include/thrust/detail/allocator/malloc_allocator.h",
+ "cuda/include/thrust/detail/allocator/malloc_allocator.inl",
+ "cuda/include/thrust/detail/allocator/no_throw_allocator.h",
+ "cuda/include/thrust/detail/allocator/tagged_allocator.h",
+ "cuda/include/thrust/detail/allocator/tagged_allocator.inl",
+ "cuda/include/thrust/detail/allocator/temporary_allocator.h",
+ "cuda/include/thrust/detail/allocator/temporary_allocator.inl",
+ "cuda/include/thrust/detail/binary_search.inl",
+ "cuda/include/thrust/detail/complex/arithmetic.h",
+ "cuda/include/thrust/detail/complex/c99math.h",
+ "cuda/include/thrust/detail/complex/catrig.h",
+ "cuda/include/thrust/detail/complex/catrigf.h",
+ "cuda/include/thrust/detail/complex/ccosh.h",
+ "cuda/include/thrust/detail/complex/ccoshf.h",
+ "cuda/include/thrust/detail/complex/cexp.h",
+ "cuda/include/thrust/detail/complex/cexpf.h",
+ "cuda/include/thrust/detail/complex/clog.h",
+ "cuda/include/thrust/detail/complex/clogf.h",
+ "cuda/include/thrust/detail/complex/complex.inl",
+ "cuda/include/thrust/detail/complex/cpow.h",
+ "cuda/include/thrust/detail/complex/cproj.h",
+ "cuda/include/thrust/detail/complex/csinh.h",
+ "cuda/include/thrust/detail/complex/csinhf.h",
+ "cuda/include/thrust/detail/complex/csqrt.h",
+ "cuda/include/thrust/detail/complex/csqrtf.h",
+ "cuda/include/thrust/detail/complex/ctanh.h",
+ "cuda/include/thrust/detail/complex/ctanhf.h",
+ "cuda/include/thrust/detail/complex/math_private.h",
+ "cuda/include/thrust/detail/complex/stream.h",
+ "cuda/include/thrust/detail/config.h",
+ "cuda/include/thrust/detail/config/compiler.h",
+ "cuda/include/thrust/detail/config/compiler_fence.h",
+ "cuda/include/thrust/detail/config/config.h",
+ "cuda/include/thrust/detail/config/debug.h",
+ "cuda/include/thrust/detail/config/device_system.h",
+ "cuda/include/thrust/detail/config/exec_check_disable.h",
+ "cuda/include/thrust/detail/config/forceinline.h",
+ "cuda/include/thrust/detail/config/global_workarounds.h",
+ "cuda/include/thrust/detail/config/host_device.h",
+ "cuda/include/thrust/detail/config/host_system.h",
+ "cuda/include/thrust/detail/config/simple_defines.h",
+ "cuda/include/thrust/detail/contiguous_storage.h",
+ "cuda/include/thrust/detail/contiguous_storage.inl",
+ "cuda/include/thrust/detail/copy.h",
+ "cuda/include/thrust/detail/copy.inl",
+ "cuda/include/thrust/detail/copy_if.h",
+ "cuda/include/thrust/detail/copy_if.inl",
+ "cuda/include/thrust/detail/count.inl",
+ "cuda/include/thrust/detail/cstdint.h",
+ "cuda/include/thrust/detail/device_delete.inl",
+ "cuda/include/thrust/detail/device_free.inl",
+ "cuda/include/thrust/detail/device_malloc.inl",
+ "cuda/include/thrust/detail/device_new.inl",
+ "cuda/include/thrust/detail/device_ptr.inl",
+ "cuda/include/thrust/detail/device_reference.inl",
+ "cuda/include/thrust/detail/device_vector.inl",
+ "cuda/include/thrust/detail/dispatch/is_trivial_copy.h",
+ "cuda/include/thrust/detail/distance.inl",
+ "cuda/include/thrust/detail/equal.inl",
+ "cuda/include/thrust/detail/execute_with_allocator.h",
+ "cuda/include/thrust/detail/execution_policy.h",
+ "cuda/include/thrust/detail/extrema.inl",
+ "cuda/include/thrust/detail/fill.inl",
+ "cuda/include/thrust/detail/find.inl",
+ "cuda/include/thrust/detail/for_each.inl",
+ "cuda/include/thrust/detail/function.h",
+ "cuda/include/thrust/detail/functional.inl",
+ "cuda/include/thrust/detail/functional/actor.h",
+ "cuda/include/thrust/detail/functional/actor.inl",
+ "cuda/include/thrust/detail/functional/argument.h",
+ "cuda/include/thrust/detail/functional/composite.h",
+ "cuda/include/thrust/detail/functional/operators.h",
+ "cuda/include/thrust/detail/functional/operators/arithmetic_operators.h",
+ "cuda/include/thrust/detail/functional/operators/assignment_operator.h",
+ "cuda/include/thrust/detail/functional/operators/bitwise_operators.h",
+ "cuda/include/thrust/detail/functional/operators/compound_assignment_operators.h",
+ "cuda/include/thrust/detail/functional/operators/logical_operators.h",
+ "cuda/include/thrust/detail/functional/operators/operator_adaptors.h",
+ "cuda/include/thrust/detail/functional/operators/relational_operators.h",
+ "cuda/include/thrust/detail/functional/placeholder.h",
+ "cuda/include/thrust/detail/functional/value.h",
+ "cuda/include/thrust/detail/gather.inl",
+ "cuda/include/thrust/detail/generate.inl",
+ "cuda/include/thrust/detail/get_iterator_value.h",
+ "cuda/include/thrust/detail/host_vector.inl",
+ "cuda/include/thrust/detail/inner_product.inl",
+ "cuda/include/thrust/detail/integer_math.h",
+ "cuda/include/thrust/detail/integer_traits.h",
+ "cuda/include/thrust/detail/internal_functional.h",
+ "cuda/include/thrust/detail/logical.inl",
+ "cuda/include/thrust/detail/malloc_and_free.h",
+ "cuda/include/thrust/detail/merge.inl",
+ "cuda/include/thrust/detail/minmax.h",
+ "cuda/include/thrust/detail/mismatch.inl",
+ "cuda/include/thrust/detail/mpl/math.h",
+ "cuda/include/thrust/detail/numeric_traits.h",
+ "cuda/include/thrust/detail/overlapped_copy.h",
+ "cuda/include/thrust/detail/pair.inl",
+ "cuda/include/thrust/detail/partition.inl",
+ "cuda/include/thrust/detail/pointer.h",
+ "cuda/include/thrust/detail/pointer.inl",
+ "cuda/include/thrust/detail/preprocessor.h",
+ "cuda/include/thrust/detail/range/head_flags.h",
+ "cuda/include/thrust/detail/range/tail_flags.h",
+ "cuda/include/thrust/detail/raw_pointer_cast.h",
+ "cuda/include/thrust/detail/raw_reference_cast.h",
+ "cuda/include/thrust/detail/reduce.inl",
+ "cuda/include/thrust/detail/reference.h",
+ "cuda/include/thrust/detail/reference.inl",
+ "cuda/include/thrust/detail/reference_forward_declaration.h",
+ "cuda/include/thrust/detail/remove.inl",
+ "cuda/include/thrust/detail/replace.inl",
+ "cuda/include/thrust/detail/reverse.inl",
+ "cuda/include/thrust/detail/scan.inl",
+ "cuda/include/thrust/detail/scatter.inl",
+ "cuda/include/thrust/detail/seq.h",
+ "cuda/include/thrust/detail/sequence.inl",
+ "cuda/include/thrust/detail/set_operations.inl",
+ "cuda/include/thrust/detail/sort.inl",
+ "cuda/include/thrust/detail/static_assert.h",
+ "cuda/include/thrust/detail/static_map.h",
+ "cuda/include/thrust/detail/swap.h",
+ "cuda/include/thrust/detail/swap.inl",
+ "cuda/include/thrust/detail/swap_ranges.inl",
+ "cuda/include/thrust/detail/tabulate.inl",
+ "cuda/include/thrust/detail/temporary_array.h",
+ "cuda/include/thrust/detail/temporary_array.inl",
+ "cuda/include/thrust/detail/temporary_buffer.h",
+ "cuda/include/thrust/detail/transform.inl",
+ "cuda/include/thrust/detail/transform_reduce.inl",
+ "cuda/include/thrust/detail/transform_scan.inl",
+ "cuda/include/thrust/detail/trivial_sequence.h",
+ "cuda/include/thrust/detail/tuple.inl",
+ "cuda/include/thrust/detail/tuple_meta_transform.h",
+ "cuda/include/thrust/detail/tuple_transform.h",
+ "cuda/include/thrust/detail/type_traits.h",
+ "cuda/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h",
+ "cuda/include/thrust/detail/type_traits/function_traits.h",
+ "cuda/include/thrust/detail/type_traits/has_member_function.h",
+ "cuda/include/thrust/detail/type_traits/has_nested_type.h",
+ "cuda/include/thrust/detail/type_traits/has_trivial_assign.h",
+ "cuda/include/thrust/detail/type_traits/is_call_possible.h",
+ "cuda/include/thrust/detail/type_traits/is_metafunction_defined.h",
+ "cuda/include/thrust/detail/type_traits/iterator/is_discard_iterator.h",
+ "cuda/include/thrust/detail/type_traits/iterator/is_output_iterator.h",
+ "cuda/include/thrust/detail/type_traits/minimum_type.h",
+ "cuda/include/thrust/detail/type_traits/pointer_traits.h",
+ "cuda/include/thrust/detail/type_traits/result_of_adaptable_function.h",
+ "cuda/include/thrust/detail/uninitialized_copy.inl",
+ "cuda/include/thrust/detail/uninitialized_fill.inl",
+ "cuda/include/thrust/detail/unique.inl",
+ "cuda/include/thrust/detail/use_default.h",
+ "cuda/include/thrust/detail/util/align.h",
+ "cuda/include/thrust/detail/util/blocking.h",
+ "cuda/include/thrust/detail/vector_base.h",
+ "cuda/include/thrust/detail/vector_base.inl",
+ "cuda/include/thrust/device_allocator.h",
+ "cuda/include/thrust/device_delete.h",
+ "cuda/include/thrust/device_free.h",
+ "cuda/include/thrust/device_malloc.h",
+ "cuda/include/thrust/device_malloc_allocator.h",
+ "cuda/include/thrust/device_new.h",
+ "cuda/include/thrust/device_new_allocator.h",
+ "cuda/include/thrust/device_ptr.h",
+ "cuda/include/thrust/device_reference.h",
+ "cuda/include/thrust/device_vector.h",
+ "cuda/include/thrust/distance.h",
+ "cuda/include/thrust/equal.h",
+ "cuda/include/thrust/execution_policy.h",
+ "cuda/include/thrust/extrema.h",
+ "cuda/include/thrust/fill.h",
+ "cuda/include/thrust/find.h",
+ "cuda/include/thrust/for_each.h",
+ "cuda/include/thrust/functional.h",
+ "cuda/include/thrust/gather.h",
+ "cuda/include/thrust/generate.h",
+ "cuda/include/thrust/host_vector.h",
+ "cuda/include/thrust/inner_product.h",
+ "cuda/include/thrust/iterator/constant_iterator.h",
+ "cuda/include/thrust/iterator/counting_iterator.h",
+ "cuda/include/thrust/iterator/detail/any_assign.h",
+ "cuda/include/thrust/iterator/detail/any_system_tag.h",
+ "cuda/include/thrust/iterator/detail/constant_iterator_base.h",
+ "cuda/include/thrust/iterator/detail/counting_iterator.inl",
+ "cuda/include/thrust/iterator/detail/device_system_tag.h",
+ "cuda/include/thrust/iterator/detail/discard_iterator_base.h",
+ "cuda/include/thrust/iterator/detail/distance_from_result.h",
+ "cuda/include/thrust/iterator/detail/host_system_tag.h",
+ "cuda/include/thrust/iterator/detail/is_iterator_category.h",
+ "cuda/include/thrust/iterator/detail/is_trivial_iterator.h",
+ "cuda/include/thrust/iterator/detail/iterator_adaptor_base.h",
+ "cuda/include/thrust/iterator/detail/iterator_category_to_system.h",
+ "cuda/include/thrust/iterator/detail/iterator_category_to_traversal.h",
+ "cuda/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h",
+ "cuda/include/thrust/iterator/detail/iterator_facade_category.h",
+ "cuda/include/thrust/iterator/detail/iterator_traits.inl",
+ "cuda/include/thrust/iterator/detail/iterator_traversal_tags.h",
+ "cuda/include/thrust/iterator/detail/join_iterator.h",
+ "cuda/include/thrust/iterator/detail/minimum_category.h",
+ "cuda/include/thrust/iterator/detail/minimum_system.h",
+ "cuda/include/thrust/iterator/detail/normal_iterator.h",
+ "cuda/include/thrust/iterator/detail/permutation_iterator_base.h",
+ "cuda/include/thrust/iterator/detail/retag.h",
+ "cuda/include/thrust/iterator/detail/reverse_iterator.inl",
+ "cuda/include/thrust/iterator/detail/reverse_iterator_base.h",
+ "cuda/include/thrust/iterator/detail/tagged_iterator.h",
+ "cuda/include/thrust/iterator/detail/transform_iterator.inl",
+ "cuda/include/thrust/iterator/detail/transform_output_iterator.inl",
+ "cuda/include/thrust/iterator/detail/tuple_of_iterator_references.h",
+ "cuda/include/thrust/iterator/detail/universal_categories.h",
+ "cuda/include/thrust/iterator/detail/zip_iterator.inl",
+ "cuda/include/thrust/iterator/detail/zip_iterator_base.h",
+ "cuda/include/thrust/iterator/discard_iterator.h",
+ "cuda/include/thrust/iterator/iterator_adaptor.h",
+ "cuda/include/thrust/iterator/iterator_categories.h",
+ "cuda/include/thrust/iterator/iterator_facade.h",
+ "cuda/include/thrust/iterator/iterator_traits.h",
+ "cuda/include/thrust/iterator/permutation_iterator.h",
+ "cuda/include/thrust/iterator/retag.h",
+ "cuda/include/thrust/iterator/reverse_iterator.h",
+ "cuda/include/thrust/iterator/transform_iterator.h",
+ "cuda/include/thrust/iterator/transform_output_iterator.h",
+ "cuda/include/thrust/iterator/zip_iterator.h",
+ "cuda/include/thrust/logical.h",
+ "cuda/include/thrust/memory.h",
+ "cuda/include/thrust/merge.h",
+ "cuda/include/thrust/mismatch.h",
+ "cuda/include/thrust/pair.h",
+ "cuda/include/thrust/partition.h",
+ "cuda/include/thrust/random.h",
+ "cuda/include/thrust/random/detail/discard_block_engine.inl",
+ "cuda/include/thrust/random/detail/linear_congruential_engine.inl",
+ "cuda/include/thrust/random/detail/linear_congruential_engine_discard.h",
+ "cuda/include/thrust/random/detail/linear_feedback_shift_engine.inl",
+ "cuda/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h",
+ "cuda/include/thrust/random/detail/mod.h",
+ "cuda/include/thrust/random/detail/normal_distribution.inl",
+ "cuda/include/thrust/random/detail/normal_distribution_base.h",
+ "cuda/include/thrust/random/detail/random_core_access.h",
+ "cuda/include/thrust/random/detail/subtract_with_carry_engine.inl",
+ "cuda/include/thrust/random/detail/uniform_int_distribution.inl",
+ "cuda/include/thrust/random/detail/uniform_real_distribution.inl",
+ "cuda/include/thrust/random/detail/xor_combine_engine.inl",
+ "cuda/include/thrust/random/detail/xor_combine_engine_max.h",
+ "cuda/include/thrust/random/discard_block_engine.h",
+ "cuda/include/thrust/random/linear_congruential_engine.h",
+ "cuda/include/thrust/random/linear_feedback_shift_engine.h",
+ "cuda/include/thrust/random/normal_distribution.h",
+ "cuda/include/thrust/random/subtract_with_carry_engine.h",
+ "cuda/include/thrust/random/uniform_int_distribution.h",
+ "cuda/include/thrust/random/uniform_real_distribution.h",
+ "cuda/include/thrust/random/xor_combine_engine.h",
+ "cuda/include/thrust/reduce.h",
+ "cuda/include/thrust/remove.h",
+ "cuda/include/thrust/replace.h",
+ "cuda/include/thrust/reverse.h",
+ "cuda/include/thrust/scan.h",
+ "cuda/include/thrust/scatter.h",
+ "cuda/include/thrust/sequence.h",
+ "cuda/include/thrust/set_operations.h",
+ "cuda/include/thrust/sort.h",
+ "cuda/include/thrust/swap.h",
+ "cuda/include/thrust/system/cpp/detail/adjacent_difference.h",
+ "cuda/include/thrust/system/cpp/detail/assign_value.h",
+ "cuda/include/thrust/system/cpp/detail/binary_search.h",
+ "cuda/include/thrust/system/cpp/detail/copy.h",
+ "cuda/include/thrust/system/cpp/detail/copy_if.h",
+ "cuda/include/thrust/system/cpp/detail/count.h",
+ "cuda/include/thrust/system/cpp/detail/equal.h",
+ "cuda/include/thrust/system/cpp/detail/execution_policy.h",
+ "cuda/include/thrust/system/cpp/detail/extrema.h",
+ "cuda/include/thrust/system/cpp/detail/fill.h",
+ "cuda/include/thrust/system/cpp/detail/find.h",
+ "cuda/include/thrust/system/cpp/detail/for_each.h",
+ "cuda/include/thrust/system/cpp/detail/gather.h",
+ "cuda/include/thrust/system/cpp/detail/generate.h",
+ "cuda/include/thrust/system/cpp/detail/get_value.h",
+ "cuda/include/thrust/system/cpp/detail/inner_product.h",
+ "cuda/include/thrust/system/cpp/detail/iter_swap.h",
+ "cuda/include/thrust/system/cpp/detail/logical.h",
+ "cuda/include/thrust/system/cpp/detail/malloc_and_free.h",
+ "cuda/include/thrust/system/cpp/detail/memory.inl",
+ "cuda/include/thrust/system/cpp/detail/merge.h",
+ "cuda/include/thrust/system/cpp/detail/mismatch.h",
+ "cuda/include/thrust/system/cpp/detail/par.h",
+ "cuda/include/thrust/system/cpp/detail/partition.h",
+ "cuda/include/thrust/system/cpp/detail/reduce.h",
+ "cuda/include/thrust/system/cpp/detail/reduce_by_key.h",
+ "cuda/include/thrust/system/cpp/detail/remove.h",
+ "cuda/include/thrust/system/cpp/detail/replace.h",
+ "cuda/include/thrust/system/cpp/detail/reverse.h",
+ "cuda/include/thrust/system/cpp/detail/scan.h",
+ "cuda/include/thrust/system/cpp/detail/scan_by_key.h",
+ "cuda/include/thrust/system/cpp/detail/scatter.h",
+ "cuda/include/thrust/system/cpp/detail/sequence.h",
+ "cuda/include/thrust/system/cpp/detail/set_operations.h",
+ "cuda/include/thrust/system/cpp/detail/sort.h",
+ "cuda/include/thrust/system/cpp/detail/swap_ranges.h",
+ "cuda/include/thrust/system/cpp/detail/tabulate.h",
+ "cuda/include/thrust/system/cpp/detail/temporary_buffer.h",
+ "cuda/include/thrust/system/cpp/detail/transform.h",
+ "cuda/include/thrust/system/cpp/detail/transform_reduce.h",
+ "cuda/include/thrust/system/cpp/detail/transform_scan.h",
+ "cuda/include/thrust/system/cpp/detail/uninitialized_copy.h",
+ "cuda/include/thrust/system/cpp/detail/uninitialized_fill.h",
+ "cuda/include/thrust/system/cpp/detail/unique.h",
+ "cuda/include/thrust/system/cpp/detail/unique_by_key.h",
+ "cuda/include/thrust/system/cpp/detail/vector.inl",
+ "cuda/include/thrust/system/cpp/execution_policy.h",
+ "cuda/include/thrust/system/cpp/memory.h",
+ "cuda/include/thrust/system/cpp/vector.h",
+ "cuda/include/thrust/system/cuda/config.h",
+ "cuda/include/thrust/system/cuda/detail/adjacent_difference.h",
+ "cuda/include/thrust/system/cuda/detail/assign_value.h",
+ "cuda/include/thrust/system/cuda/detail/binary_search.h",
+ "cuda/include/thrust/system/cuda/detail/copy.h",
+ "cuda/include/thrust/system/cuda/detail/copy_if.h",
+ "cuda/include/thrust/system/cuda/detail/core/agent_launcher.h",
+ "cuda/include/thrust/system/cuda/detail/core/alignment.h",
+ "cuda/include/thrust/system/cuda/detail/core/triple_chevron_launch.h",
+ "cuda/include/thrust/system/cuda/detail/core/util.h",
+ "cuda/include/thrust/system/cuda/detail/count.h",
+ "cuda/include/thrust/system/cuda/detail/cross_system.h",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_histogram.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_downsweep.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_upsweep.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_reduce_by_key.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_rle.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_scan.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_segment_fixup.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_select_if.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_orig.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/agent/single_pass_scan_operators.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_adjacent_difference.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_load.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_scan.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_shuffle.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/block_store.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans2.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans3.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/cub.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_partition.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_scan.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_segmented_radix_sort.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_segmented_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_select.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/device_spmv.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_histogram.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_radix_sort.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce_by_key.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_rle.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_scan.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_select_if.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_orig.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/host/mutex.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/discard_output_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/thread/thread_search.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_allocator.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_arch.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_debug.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_device.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_macro.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_namespace.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_ptx.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/util_type.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh",
+ "cuda/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh",
+ "cuda/include/thrust/system/cuda/detail/equal.h",
+ "cuda/include/thrust/system/cuda/detail/error.inl",
+ "cuda/include/thrust/system/cuda/detail/execution_policy.h",
+ "cuda/include/thrust/system/cuda/detail/extrema.h",
+ "cuda/include/thrust/system/cuda/detail/fill.h",
+ "cuda/include/thrust/system/cuda/detail/find.h",
+ "cuda/include/thrust/system/cuda/detail/for_each.h",
+ "cuda/include/thrust/system/cuda/detail/gather.h",
+ "cuda/include/thrust/system/cuda/detail/generate.h",
+ "cuda/include/thrust/system/cuda/detail/get_value.h",
+ "cuda/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h",
+ "cuda/include/thrust/system/cuda/detail/guarded_driver_types.h",
+ "cuda/include/thrust/system/cuda/detail/inner_product.h",
+ "cuda/include/thrust/system/cuda/detail/internal/copy_cross_system.h",
+ "cuda/include/thrust/system/cuda/detail/internal/copy_device_to_device.h",
+ "cuda/include/thrust/system/cuda/detail/iter_swap.h",
+ "cuda/include/thrust/system/cuda/detail/logical.h",
+ "cuda/include/thrust/system/cuda/detail/malloc_and_free.h",
+ "cuda/include/thrust/system/cuda/detail/memory.inl",
+ "cuda/include/thrust/system/cuda/detail/merge.h",
+ "cuda/include/thrust/system/cuda/detail/mismatch.h",
+ "cuda/include/thrust/system/cuda/detail/par.h",
+ "cuda/include/thrust/system/cuda/detail/par_to_seq.h",
+ "cuda/include/thrust/system/cuda/detail/parallel_for.h",
+ "cuda/include/thrust/system/cuda/detail/partition.h",
+ "cuda/include/thrust/system/cuda/detail/reduce.h",
+ "cuda/include/thrust/system/cuda/detail/reduce_by_key.h",
+ "cuda/include/thrust/system/cuda/detail/remove.h",
+ "cuda/include/thrust/system/cuda/detail/replace.h",
+ "cuda/include/thrust/system/cuda/detail/reverse.h",
+ "cuda/include/thrust/system/cuda/detail/scan.h",
+ "cuda/include/thrust/system/cuda/detail/scan_by_key.h",
+ "cuda/include/thrust/system/cuda/detail/scatter.h",
+ "cuda/include/thrust/system/cuda/detail/sequence.h",
+ "cuda/include/thrust/system/cuda/detail/set_operations.h",
+ "cuda/include/thrust/system/cuda/detail/sort.h",
+ "cuda/include/thrust/system/cuda/detail/swap_ranges.h",
+ "cuda/include/thrust/system/cuda/detail/tabulate.h",
+ "cuda/include/thrust/system/cuda/detail/temporary_buffer.h",
+ "cuda/include/thrust/system/cuda/detail/terminate.h",
+ "cuda/include/thrust/system/cuda/detail/transform.h",
+ "cuda/include/thrust/system/cuda/detail/transform_reduce.h",
+ "cuda/include/thrust/system/cuda/detail/transform_scan.h",
+ "cuda/include/thrust/system/cuda/detail/uninitialized_copy.h",
+ "cuda/include/thrust/system/cuda/detail/uninitialized_fill.h",
+ "cuda/include/thrust/system/cuda/detail/unique.h",
+ "cuda/include/thrust/system/cuda/detail/unique_by_key.h",
+ "cuda/include/thrust/system/cuda/detail/util.h",
+ "cuda/include/thrust/system/cuda/detail/vector.inl",
+ "cuda/include/thrust/system/cuda/error.h",
+ "cuda/include/thrust/system/cuda/execution_policy.h",
+ "cuda/include/thrust/system/cuda/experimental/pinned_allocator.h",
+ "cuda/include/thrust/system/cuda/memory.h",
+ "cuda/include/thrust/system/cuda/vector.h",
+ "cuda/include/thrust/system/detail/adl/adjacent_difference.h",
+ "cuda/include/thrust/system/detail/adl/assign_value.h",
+ "cuda/include/thrust/system/detail/adl/binary_search.h",
+ "cuda/include/thrust/system/detail/adl/copy.h",
+ "cuda/include/thrust/system/detail/adl/copy_if.h",
+ "cuda/include/thrust/system/detail/adl/count.h",
+ "cuda/include/thrust/system/detail/adl/equal.h",
+ "cuda/include/thrust/system/detail/adl/extrema.h",
+ "cuda/include/thrust/system/detail/adl/fill.h",
+ "cuda/include/thrust/system/detail/adl/find.h",
+ "cuda/include/thrust/system/detail/adl/for_each.h",
+ "cuda/include/thrust/system/detail/adl/gather.h",
+ "cuda/include/thrust/system/detail/adl/generate.h",
+ "cuda/include/thrust/system/detail/adl/get_value.h",
+ "cuda/include/thrust/system/detail/adl/inner_product.h",
+ "cuda/include/thrust/system/detail/adl/iter_swap.h",
+ "cuda/include/thrust/system/detail/adl/logical.h",
+ "cuda/include/thrust/system/detail/adl/malloc_and_free.h",
+ "cuda/include/thrust/system/detail/adl/merge.h",
+ "cuda/include/thrust/system/detail/adl/mismatch.h",
+ "cuda/include/thrust/system/detail/adl/partition.h",
+ "cuda/include/thrust/system/detail/adl/reduce.h",
+ "cuda/include/thrust/system/detail/adl/reduce_by_key.h",
+ "cuda/include/thrust/system/detail/adl/remove.h",
+ "cuda/include/thrust/system/detail/adl/replace.h",
+ "cuda/include/thrust/system/detail/adl/reverse.h",
+ "cuda/include/thrust/system/detail/adl/scan.h",
+ "cuda/include/thrust/system/detail/adl/scan_by_key.h",
+ "cuda/include/thrust/system/detail/adl/scatter.h",
+ "cuda/include/thrust/system/detail/adl/sequence.h",
+ "cuda/include/thrust/system/detail/adl/set_operations.h",
+ "cuda/include/thrust/system/detail/adl/sort.h",
+ "cuda/include/thrust/system/detail/adl/swap_ranges.h",
+ "cuda/include/thrust/system/detail/adl/tabulate.h",
+ "cuda/include/thrust/system/detail/adl/temporary_buffer.h",
+ "cuda/include/thrust/system/detail/adl/transform.h",
+ "cuda/include/thrust/system/detail/adl/transform_reduce.h",
+ "cuda/include/thrust/system/detail/adl/transform_scan.h",
+ "cuda/include/thrust/system/detail/adl/uninitialized_copy.h",
+ "cuda/include/thrust/system/detail/adl/uninitialized_fill.h",
+ "cuda/include/thrust/system/detail/adl/unique.h",
+ "cuda/include/thrust/system/detail/adl/unique_by_key.h",
+ "cuda/include/thrust/system/detail/bad_alloc.h",
+ "cuda/include/thrust/system/detail/errno.h",
+ "cuda/include/thrust/system/detail/error_category.inl",
+ "cuda/include/thrust/system/detail/error_code.inl",
+ "cuda/include/thrust/system/detail/error_condition.inl",
+ "cuda/include/thrust/system/detail/generic/adjacent_difference.h",
+ "cuda/include/thrust/system/detail/generic/adjacent_difference.inl",
+ "cuda/include/thrust/system/detail/generic/advance.h",
+ "cuda/include/thrust/system/detail/generic/advance.inl",
+ "cuda/include/thrust/system/detail/generic/binary_search.h",
+ "cuda/include/thrust/system/detail/generic/binary_search.inl",
+ "cuda/include/thrust/system/detail/generic/copy.h",
+ "cuda/include/thrust/system/detail/generic/copy.inl",
+ "cuda/include/thrust/system/detail/generic/copy_if.h",
+ "cuda/include/thrust/system/detail/generic/copy_if.inl",
+ "cuda/include/thrust/system/detail/generic/count.h",
+ "cuda/include/thrust/system/detail/generic/count.inl",
+ "cuda/include/thrust/system/detail/generic/distance.h",
+ "cuda/include/thrust/system/detail/generic/distance.inl",
+ "cuda/include/thrust/system/detail/generic/equal.h",
+ "cuda/include/thrust/system/detail/generic/equal.inl",
+ "cuda/include/thrust/system/detail/generic/extrema.h",
+ "cuda/include/thrust/system/detail/generic/extrema.inl",
+ "cuda/include/thrust/system/detail/generic/fill.h",
+ "cuda/include/thrust/system/detail/generic/find.h",
+ "cuda/include/thrust/system/detail/generic/find.inl",
+ "cuda/include/thrust/system/detail/generic/for_each.h",
+ "cuda/include/thrust/system/detail/generic/gather.h",
+ "cuda/include/thrust/system/detail/generic/gather.inl",
+ "cuda/include/thrust/system/detail/generic/generate.h",
+ "cuda/include/thrust/system/detail/generic/generate.inl",
+ "cuda/include/thrust/system/detail/generic/inner_product.h",
+ "cuda/include/thrust/system/detail/generic/inner_product.inl",
+ "cuda/include/thrust/system/detail/generic/logical.h",
+ "cuda/include/thrust/system/detail/generic/memory.h",
+ "cuda/include/thrust/system/detail/generic/memory.inl",
+ "cuda/include/thrust/system/detail/generic/merge.h",
+ "cuda/include/thrust/system/detail/generic/merge.inl",
+ "cuda/include/thrust/system/detail/generic/mismatch.h",
+ "cuda/include/thrust/system/detail/generic/mismatch.inl",
+ "cuda/include/thrust/system/detail/generic/partition.h",
+ "cuda/include/thrust/system/detail/generic/partition.inl",
+ "cuda/include/thrust/system/detail/generic/reduce.h",
+ "cuda/include/thrust/system/detail/generic/reduce.inl",
+ "cuda/include/thrust/system/detail/generic/reduce_by_key.h",
+ "cuda/include/thrust/system/detail/generic/reduce_by_key.inl",
+ "cuda/include/thrust/system/detail/generic/remove.h",
+ "cuda/include/thrust/system/detail/generic/remove.inl",
+ "cuda/include/thrust/system/detail/generic/replace.h",
+ "cuda/include/thrust/system/detail/generic/replace.inl",
+ "cuda/include/thrust/system/detail/generic/reverse.h",
+ "cuda/include/thrust/system/detail/generic/reverse.inl",
+ "cuda/include/thrust/system/detail/generic/scalar/binary_search.h",
+ "cuda/include/thrust/system/detail/generic/scalar/binary_search.inl",
+ "cuda/include/thrust/system/detail/generic/scan.h",
+ "cuda/include/thrust/system/detail/generic/scan.inl",
+ "cuda/include/thrust/system/detail/generic/scan_by_key.h",
+ "cuda/include/thrust/system/detail/generic/scan_by_key.inl",
+ "cuda/include/thrust/system/detail/generic/scatter.h",
+ "cuda/include/thrust/system/detail/generic/scatter.inl",
+ "cuda/include/thrust/system/detail/generic/select_system.h",
+ "cuda/include/thrust/system/detail/generic/sequence.h",
+ "cuda/include/thrust/system/detail/generic/sequence.inl",
+ "cuda/include/thrust/system/detail/generic/set_operations.h",
+ "cuda/include/thrust/system/detail/generic/set_operations.inl",
+ "cuda/include/thrust/system/detail/generic/sort.h",
+ "cuda/include/thrust/system/detail/generic/sort.inl",
+ "cuda/include/thrust/system/detail/generic/swap_ranges.h",
+ "cuda/include/thrust/system/detail/generic/swap_ranges.inl",
+ "cuda/include/thrust/system/detail/generic/tabulate.h",
+ "cuda/include/thrust/system/detail/generic/tabulate.inl",
+ "cuda/include/thrust/system/detail/generic/tag.h",
+ "cuda/include/thrust/system/detail/generic/temporary_buffer.h",
+ "cuda/include/thrust/system/detail/generic/temporary_buffer.inl",
+ "cuda/include/thrust/system/detail/generic/transform.h",
+ "cuda/include/thrust/system/detail/generic/transform.inl",
+ "cuda/include/thrust/system/detail/generic/transform_reduce.h",
+ "cuda/include/thrust/system/detail/generic/transform_reduce.inl",
+ "cuda/include/thrust/system/detail/generic/transform_scan.h",
+ "cuda/include/thrust/system/detail/generic/transform_scan.inl",
+ "cuda/include/thrust/system/detail/generic/type_traits.h",
+ "cuda/include/thrust/system/detail/generic/uninitialized_copy.h",
+ "cuda/include/thrust/system/detail/generic/uninitialized_copy.inl",
+ "cuda/include/thrust/system/detail/generic/uninitialized_fill.h",
+ "cuda/include/thrust/system/detail/generic/uninitialized_fill.inl",
+ "cuda/include/thrust/system/detail/generic/unique.h",
+ "cuda/include/thrust/system/detail/generic/unique.inl",
+ "cuda/include/thrust/system/detail/generic/unique_by_key.h",
+ "cuda/include/thrust/system/detail/generic/unique_by_key.inl",
+ "cuda/include/thrust/system/detail/internal/decompose.h",
+ "cuda/include/thrust/system/detail/sequential/adjacent_difference.h",
+ "cuda/include/thrust/system/detail/sequential/assign_value.h",
+ "cuda/include/thrust/system/detail/sequential/binary_search.h",
+ "cuda/include/thrust/system/detail/sequential/copy.h",
+ "cuda/include/thrust/system/detail/sequential/copy.inl",
+ "cuda/include/thrust/system/detail/sequential/copy_backward.h",
+ "cuda/include/thrust/system/detail/sequential/copy_if.h",
+ "cuda/include/thrust/system/detail/sequential/count.h",
+ "cuda/include/thrust/system/detail/sequential/equal.h",
+ "cuda/include/thrust/system/detail/sequential/execution_policy.h",
+ "cuda/include/thrust/system/detail/sequential/extrema.h",
+ "cuda/include/thrust/system/detail/sequential/fill.h",
+ "cuda/include/thrust/system/detail/sequential/find.h",
+ "cuda/include/thrust/system/detail/sequential/for_each.h",
+ "cuda/include/thrust/system/detail/sequential/gather.h",
+ "cuda/include/thrust/system/detail/sequential/general_copy.h",
+ "cuda/include/thrust/system/detail/sequential/generate.h",
+ "cuda/include/thrust/system/detail/sequential/get_value.h",
+ "cuda/include/thrust/system/detail/sequential/inner_product.h",
+ "cuda/include/thrust/system/detail/sequential/insertion_sort.h",
+ "cuda/include/thrust/system/detail/sequential/iter_swap.h",
+ "cuda/include/thrust/system/detail/sequential/logical.h",
+ "cuda/include/thrust/system/detail/sequential/malloc_and_free.h",
+ "cuda/include/thrust/system/detail/sequential/merge.h",
+ "cuda/include/thrust/system/detail/sequential/merge.inl",
+ "cuda/include/thrust/system/detail/sequential/mismatch.h",
+ "cuda/include/thrust/system/detail/sequential/partition.h",
+ "cuda/include/thrust/system/detail/sequential/reduce.h",
+ "cuda/include/thrust/system/detail/sequential/reduce_by_key.h",
+ "cuda/include/thrust/system/detail/sequential/remove.h",
+ "cuda/include/thrust/system/detail/sequential/replace.h",
+ "cuda/include/thrust/system/detail/sequential/reverse.h",
+ "cuda/include/thrust/system/detail/sequential/scan.h",
+ "cuda/include/thrust/system/detail/sequential/scan_by_key.h",
+ "cuda/include/thrust/system/detail/sequential/scatter.h",
+ "cuda/include/thrust/system/detail/sequential/sequence.h",
+ "cuda/include/thrust/system/detail/sequential/set_operations.h",
+ "cuda/include/thrust/system/detail/sequential/sort.h",
+ "cuda/include/thrust/system/detail/sequential/sort.inl",
+ "cuda/include/thrust/system/detail/sequential/stable_merge_sort.h",
+ "cuda/include/thrust/system/detail/sequential/stable_merge_sort.inl",
+ "cuda/include/thrust/system/detail/sequential/stable_primitive_sort.h",
+ "cuda/include/thrust/system/detail/sequential/stable_primitive_sort.inl",
+ "cuda/include/thrust/system/detail/sequential/stable_radix_sort.h",
+ "cuda/include/thrust/system/detail/sequential/stable_radix_sort.inl",
+ "cuda/include/thrust/system/detail/sequential/swap_ranges.h",
+ "cuda/include/thrust/system/detail/sequential/tabulate.h",
+ "cuda/include/thrust/system/detail/sequential/temporary_buffer.h",
+ "cuda/include/thrust/system/detail/sequential/transform.h",
+ "cuda/include/thrust/system/detail/sequential/transform_reduce.h",
+ "cuda/include/thrust/system/detail/sequential/transform_scan.h",
+ "cuda/include/thrust/system/detail/sequential/trivial_copy.h",
+ "cuda/include/thrust/system/detail/sequential/uninitialized_copy.h",
+ "cuda/include/thrust/system/detail/sequential/uninitialized_fill.h",
+ "cuda/include/thrust/system/detail/sequential/unique.h",
+ "cuda/include/thrust/system/detail/sequential/unique_by_key.h",
+ "cuda/include/thrust/system/detail/system_error.inl",
+ "cuda/include/thrust/system/error_code.h",
+ "cuda/include/thrust/system/omp/detail/adjacent_difference.h",
+ "cuda/include/thrust/system/omp/detail/assign_value.h",
+ "cuda/include/thrust/system/omp/detail/binary_search.h",
+ "cuda/include/thrust/system/omp/detail/copy.h",
+ "cuda/include/thrust/system/omp/detail/copy.inl",
+ "cuda/include/thrust/system/omp/detail/copy_if.h",
+ "cuda/include/thrust/system/omp/detail/copy_if.inl",
+ "cuda/include/thrust/system/omp/detail/count.h",
+ "cuda/include/thrust/system/omp/detail/default_decomposition.h",
+ "cuda/include/thrust/system/omp/detail/default_decomposition.inl",
+ "cuda/include/thrust/system/omp/detail/equal.h",
+ "cuda/include/thrust/system/omp/detail/execution_policy.h",
+ "cuda/include/thrust/system/omp/detail/extrema.h",
+ "cuda/include/thrust/system/omp/detail/fill.h",
+ "cuda/include/thrust/system/omp/detail/find.h",
+ "cuda/include/thrust/system/omp/detail/for_each.h",
+ "cuda/include/thrust/system/omp/detail/for_each.inl",
+ "cuda/include/thrust/system/omp/detail/gather.h",
+ "cuda/include/thrust/system/omp/detail/generate.h",
+ "cuda/include/thrust/system/omp/detail/get_value.h",
+ "cuda/include/thrust/system/omp/detail/inner_product.h",
+ "cuda/include/thrust/system/omp/detail/iter_swap.h",
+ "cuda/include/thrust/system/omp/detail/logical.h",
+ "cuda/include/thrust/system/omp/detail/malloc_and_free.h",
+ "cuda/include/thrust/system/omp/detail/memory.inl",
+ "cuda/include/thrust/system/omp/detail/merge.h",
+ "cuda/include/thrust/system/omp/detail/mismatch.h",
+ "cuda/include/thrust/system/omp/detail/par.h",
+ "cuda/include/thrust/system/omp/detail/partition.h",
+ "cuda/include/thrust/system/omp/detail/partition.inl",
+ "cuda/include/thrust/system/omp/detail/reduce.h",
+ "cuda/include/thrust/system/omp/detail/reduce.inl",
+ "cuda/include/thrust/system/omp/detail/reduce_by_key.h",
+ "cuda/include/thrust/system/omp/detail/reduce_by_key.inl",
+ "cuda/include/thrust/system/omp/detail/reduce_intervals.h",
+ "cuda/include/thrust/system/omp/detail/reduce_intervals.inl",
+ "cuda/include/thrust/system/omp/detail/remove.h",
+ "cuda/include/thrust/system/omp/detail/remove.inl",
+ "cuda/include/thrust/system/omp/detail/replace.h",
+ "cuda/include/thrust/system/omp/detail/reverse.h",
+ "cuda/include/thrust/system/omp/detail/scan.h",
+ "cuda/include/thrust/system/omp/detail/scan_by_key.h",
+ "cuda/include/thrust/system/omp/detail/scatter.h",
+ "cuda/include/thrust/system/omp/detail/sequence.h",
+ "cuda/include/thrust/system/omp/detail/set_operations.h",
+ "cuda/include/thrust/system/omp/detail/sort.h",
+ "cuda/include/thrust/system/omp/detail/sort.inl",
+ "cuda/include/thrust/system/omp/detail/swap_ranges.h",
+ "cuda/include/thrust/system/omp/detail/tabulate.h",
+ "cuda/include/thrust/system/omp/detail/temporary_buffer.h",
+ "cuda/include/thrust/system/omp/detail/transform.h",
+ "cuda/include/thrust/system/omp/detail/transform_reduce.h",
+ "cuda/include/thrust/system/omp/detail/transform_scan.h",
+ "cuda/include/thrust/system/omp/detail/uninitialized_copy.h",
+ "cuda/include/thrust/system/omp/detail/uninitialized_fill.h",
+ "cuda/include/thrust/system/omp/detail/unique.h",
+ "cuda/include/thrust/system/omp/detail/unique.inl",
+ "cuda/include/thrust/system/omp/detail/unique_by_key.h",
+ "cuda/include/thrust/system/omp/detail/unique_by_key.inl",
+ "cuda/include/thrust/system/omp/detail/vector.inl",
+ "cuda/include/thrust/system/omp/execution_policy.h",
+ "cuda/include/thrust/system/omp/memory.h",
+ "cuda/include/thrust/system/omp/vector.h",
+ "cuda/include/thrust/system/system_error.h",
+ "cuda/include/thrust/system/tbb/detail/adjacent_difference.h",
+ "cuda/include/thrust/system/tbb/detail/assign_value.h",
+ "cuda/include/thrust/system/tbb/detail/binary_search.h",
+ "cuda/include/thrust/system/tbb/detail/copy.h",
+ "cuda/include/thrust/system/tbb/detail/copy.inl",
+ "cuda/include/thrust/system/tbb/detail/copy_if.h",
+ "cuda/include/thrust/system/tbb/detail/copy_if.inl",
+ "cuda/include/thrust/system/tbb/detail/count.h",
+ "cuda/include/thrust/system/tbb/detail/equal.h",
+ "cuda/include/thrust/system/tbb/detail/execution_policy.h",
+ "cuda/include/thrust/system/tbb/detail/extrema.h",
+ "cuda/include/thrust/system/tbb/detail/fill.h",
+ "cuda/include/thrust/system/tbb/detail/find.h",
+ "cuda/include/thrust/system/tbb/detail/for_each.h",
+ "cuda/include/thrust/system/tbb/detail/for_each.inl",
+ "cuda/include/thrust/system/tbb/detail/gather.h",
+ "cuda/include/thrust/system/tbb/detail/generate.h",
+ "cuda/include/thrust/system/tbb/detail/get_value.h",
+ "cuda/include/thrust/system/tbb/detail/inner_product.h",
+ "cuda/include/thrust/system/tbb/detail/iter_swap.h",
+ "cuda/include/thrust/system/tbb/detail/logical.h",
+ "cuda/include/thrust/system/tbb/detail/malloc_and_free.h",
+ "cuda/include/thrust/system/tbb/detail/memory.inl",
+ "cuda/include/thrust/system/tbb/detail/merge.h",
+ "cuda/include/thrust/system/tbb/detail/merge.inl",
+ "cuda/include/thrust/system/tbb/detail/mismatch.h",
+ "cuda/include/thrust/system/tbb/detail/par.h",
+ "cuda/include/thrust/system/tbb/detail/partition.h",
+ "cuda/include/thrust/system/tbb/detail/partition.inl",
+ "cuda/include/thrust/system/tbb/detail/reduce.h",
+ "cuda/include/thrust/system/tbb/detail/reduce.inl",
+ "cuda/include/thrust/system/tbb/detail/reduce_by_key.h",
+ "cuda/include/thrust/system/tbb/detail/reduce_by_key.inl",
+ "cuda/include/thrust/system/tbb/detail/reduce_intervals.h",
+ "cuda/include/thrust/system/tbb/detail/remove.h",
+ "cuda/include/thrust/system/tbb/detail/remove.inl",
+ "cuda/include/thrust/system/tbb/detail/replace.h",
+ "cuda/include/thrust/system/tbb/detail/reverse.h",
+ "cuda/include/thrust/system/tbb/detail/scan.h",
+ "cuda/include/thrust/system/tbb/detail/scan.inl",
+ "cuda/include/thrust/system/tbb/detail/scan_by_key.h",
+ "cuda/include/thrust/system/tbb/detail/scatter.h",
+ "cuda/include/thrust/system/tbb/detail/sequence.h",
+ "cuda/include/thrust/system/tbb/detail/set_operations.h",
+ "cuda/include/thrust/system/tbb/detail/sort.h",
+ "cuda/include/thrust/system/tbb/detail/sort.inl",
+ "cuda/include/thrust/system/tbb/detail/swap_ranges.h",
+ "cuda/include/thrust/system/tbb/detail/tabulate.h",
+ "cuda/include/thrust/system/tbb/detail/temporary_buffer.h",
+ "cuda/include/thrust/system/tbb/detail/transform.h",
+ "cuda/include/thrust/system/tbb/detail/transform_reduce.h",
+ "cuda/include/thrust/system/tbb/detail/transform_scan.h",
+ "cuda/include/thrust/system/tbb/detail/uninitialized_copy.h",
+ "cuda/include/thrust/system/tbb/detail/uninitialized_fill.h",
+ "cuda/include/thrust/system/tbb/detail/unique.h",
+ "cuda/include/thrust/system/tbb/detail/unique.inl",
+ "cuda/include/thrust/system/tbb/detail/unique_by_key.h",
+ "cuda/include/thrust/system/tbb/detail/unique_by_key.inl",
+ "cuda/include/thrust/system/tbb/detail/vector.inl",
+ "cuda/include/thrust/system/tbb/execution_policy.h",
+ "cuda/include/thrust/system/tbb/memory.h",
+ "cuda/include/thrust/system/tbb/vector.h",
+ "cuda/include/thrust/system_error.h",
+ "cuda/include/thrust/tabulate.h",
+ "cuda/include/thrust/transform.h",
+ "cuda/include/thrust/transform_reduce.h",
+ "cuda/include/thrust/transform_scan.h",
+ "cuda/include/thrust/tuple.h",
+ "cuda/include/thrust/uninitialized_copy.h",
+ "cuda/include/thrust/uninitialized_fill.h",
+ "cuda/include/thrust/unique.h",
+ "cuda/include/thrust/version.h",
+ "cuda/include/vector_functions.h",
+ "cuda/include/vector_functions.hpp",
+ "cuda/include/vector_types.h",
+ ],
+ cmd = """
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp -f "/usr/local/cuda-10.0/include/CL/cl.h" "$(@D)/cuda/include/CL/cl.h" && cp -f "/usr/local/cuda-10.0/include/CL/cl.hpp" "$(@D)/cuda/include/CL/cl.hpp" && cp -f "/usr/local/cuda-10.0/include/CL/cl_egl.h" "$(@D)/cuda/include/CL/cl_egl.h" && cp -f "/usr/local/cuda-10.0/include/CL/cl_ext.h" "$(@D)/cuda/include/CL/cl_ext.h" && cp -f "/usr/local/cuda-10.0/include/CL/cl_gl.h" "$(@D)/cuda/include/CL/cl_gl.h" && cp -f "/usr/local/cuda-10.0/include/CL/cl_gl_ext.h" "$(@D)/cuda/include/CL/cl_gl_ext.h" && cp -f "/usr/local/cuda-10.0/include/CL/cl_platform.h" "$(@D)/cuda/include/CL/cl_platform.h" && cp -f "/usr/local/cuda-10.0/include/CL/opencl.h" "$(@D)/cuda/include/CL/opencl.h" && cp -f "/usr/local/cuda-10.0/include/builtin_types.h" "$(@D)/cuda/include/builtin_types.h" && cp -f "/usr/local/cuda-10.0/include/channel_descriptor.h" "$(@D)/cuda/include/channel_descriptor.h" && cp -f "/usr/local/cuda-10.0/include/common_functions.h" "$(@D)/cuda/include/common_functions.h" && cp -f "/usr/local/cuda-10.0/include/cooperative_groups.h" "$(@D)/cuda/include/cooperative_groups.h" && cp -f "/usr/local/cuda-10.0/include/cooperative_groups_helpers.h" "$(@D)/cuda/include/cooperative_groups_helpers.h" && cp -f "/usr/local/cuda-10.0/include/crt/common_functions.h" "$(@D)/cuda/include/crt/common_functions.h" && cp -f "/usr/local/cuda-10.0/include/crt/device_double_functions.h" "$(@D)/cuda/include/crt/device_double_functions.h" && cp -f "/usr/local/cuda-10.0/include/crt/device_double_functions.hpp" "$(@D)/cuda/include/crt/device_double_functions.hpp" && cp -f "/usr/local/cuda-10.0/include/crt/device_functions.h" "$(@D)/cuda/include/crt/device_functions.h" && cp -f "/usr/local/cuda-10.0/include/crt/device_functions.hpp" "$(@D)/cuda/include/crt/device_functions.hpp" && cp -f "/usr/local/cuda-10.0/include/crt/func_macro.h" "$(@D)/cuda/include/crt/func_macro.h" && cp -f "/usr/local/cuda-10.0/include/crt/host_config.h" "$(@D)/cuda/include/crt/host_config.h" && cp -f "/usr/local/cuda-10.0/include/crt/host_defines.h" "$(@D)/cuda/include/crt/host_defines.h" && cp -f "/usr/local/cuda-10.0/include/crt/host_runtime.h" "$(@D)/cuda/include/crt/host_runtime.h" && cp -f "/usr/local/cuda-10.0/include/crt/math_functions.h" "$(@D)/cuda/include/crt/math_functions.h" && cp -f "/usr/local/cuda-10.0/include/crt/math_functions.hpp" "$(@D)/cuda/include/crt/math_functions.hpp" && cp -f "/usr/local/cuda-10.0/include/crt/mma.h" "$(@D)/cuda/include/crt/mma.h" && cp -f "/usr/local/cuda-10.0/include/crt/mma.hpp" "$(@D)/cuda/include/crt/mma.hpp" && cp -f "/usr/local/cuda-10.0/include/crt/nvfunctional" "$(@D)/cuda/include/crt/nvfunctional" && cp -f "/usr/local/cuda-10.0/include/crt/sm_70_rt.h" "$(@D)/cuda/include/crt/sm_70_rt.h" && cp -f "/usr/local/cuda-10.0/include/crt/sm_70_rt.hpp" "$(@D)/cuda/include/crt/sm_70_rt.hpp" && cp -f "/usr/local/cuda-10.0/include/crt/storage_class.h" "$(@D)/cuda/include/crt/storage_class.h" && cp -f "/usr/local/cuda-10.0/include/cuComplex.h" "$(@D)/cuda/include/cuComplex.h" && cp -f "/usr/local/cuda-10.0/include/cublas.h" "$(@D)/cuda/include/cublas.h" && cp -f "/usr/local/cuda-10.0/include/cublasXt.h" "$(@D)/cuda/include/cublasXt.h" && cp -f "/usr/local/cuda-10.0/include/cublas_api.h" "$(@D)/cuda/include/cublas_api.h" && cp -f "/usr/local/cuda-10.0/include/cublas_v2.h" "$(@D)/cuda/include/cublas_v2.h" && cp -f "/usr/local/cuda-10.0/include/cuda.h" "$(@D)/cuda/include/cuda.h" && cp -f "/usr/local/cuda-10.0/include/cudaEGL.h" "$(@D)/cuda/include/cudaEGL.h" && cp -f "/usr/local/cuda-10.0/include/cudaGL.h" "$(@D)/cuda/include/cudaGL.h" && cp -f "/usr/local/cuda-10.0/include/cudaProfiler.h" "$(@D)/cuda/include/cudaProfiler.h" && cp -f "/usr/local/cuda-10.0/include/cudaVDPAU.h" "$(@D)/cuda/include/cudaVDPAU.h" && cp -f "/usr/local/cuda-10.0/include/cuda_device_runtime_api.h" "$(@D)/cuda/include/cuda_device_runtime_api.h" && cp -f "/usr/local/cuda-10.0/include/cuda_egl_interop.h" "$(@D)/cuda/include/cuda_egl_interop.h" && cp -f "/usr/local/cuda-10.0/include/cuda_fp16.h" "$(@D)/cuda/include/cuda_fp16.h" && cp -f "/usr/local/cuda-10.0/include/cuda_fp16.hpp" "$(@D)/cuda/include/cuda_fp16.hpp" && cp -f "/usr/local/cuda-10.0/include/cuda_gl_interop.h" "$(@D)/cuda/include/cuda_gl_interop.h" && cp -f "/usr/local/cuda-10.0/include/cuda_occupancy.h" "$(@D)/cuda/include/cuda_occupancy.h" && cp -f "/usr/local/cuda-10.0/include/cuda_profiler_api.h" "$(@D)/cuda/include/cuda_profiler_api.h" && cp -f "/usr/local/cuda-10.0/include/cuda_runtime.h" "$(@D)/cuda/include/cuda_runtime.h" && cp -f "/usr/local/cuda-10.0/include/cuda_runtime_api.h" "$(@D)/cuda/include/cuda_runtime_api.h" && cp -f "/usr/local/cuda-10.0/include/cuda_surface_types.h" "$(@D)/cuda/include/cuda_surface_types.h" && cp -f "/usr/local/cuda-10.0/include/cuda_texture_types.h" "$(@D)/cuda/include/cuda_texture_types.h" && cp -f "/usr/local/cuda-10.0/include/cuda_vdpau_interop.h" "$(@D)/cuda/include/cuda_vdpau_interop.h" && cp -f "/usr/local/cuda-10.0/include/cudalibxt.h" "$(@D)/cuda/include/cudalibxt.h" && cp -f "/usr/local/cuda-10.0/include/cudart_platform.h" "$(@D)/cuda/include/cudart_platform.h" && cp -f "/usr/local/cuda-10.0/include/cufft.h" "$(@D)/cuda/include/cufft.h" && cp -f "/usr/local/cuda-10.0/include/cufftXt.h" "$(@D)/cuda/include/cufftXt.h" && cp -f "/usr/local/cuda-10.0/include/cufftw.h" "$(@D)/cuda/include/cufftw.h" && cp -f "/usr/local/cuda-10.0/include/curand.h" "$(@D)/cuda/include/curand.h" && cp -f "/usr/local/cuda-10.0/include/curand_discrete.h" "$(@D)/cuda/include/curand_discrete.h" && cp -f "/usr/local/cuda-10.0/include/curand_discrete2.h" "$(@D)/cuda/include/curand_discrete2.h" && cp -f "/usr/local/cuda-10.0/include/curand_globals.h" "$(@D)/cuda/include/curand_globals.h" && cp -f "/usr/local/cuda-10.0/include/curand_kernel.h" "$(@D)/cuda/include/curand_kernel.h" && cp -f "/usr/local/cuda-10.0/include/curand_lognormal.h" "$(@D)/cuda/include/curand_lognormal.h" && cp -f "/usr/local/cuda-10.0/include/curand_mrg32k3a.h" "$(@D)/cuda/include/curand_mrg32k3a.h" && cp -f "/usr/local/cuda-10.0/include/curand_mtgp32.h" "$(@D)/cuda/include/curand_mtgp32.h" && cp -f "/usr/local/cuda-10.0/include/curand_mtgp32_host.h" "$(@D)/cuda/include/curand_mtgp32_host.h" && cp -f "/usr/local/cuda-10.0/include/curand_mtgp32_kernel.h" "$(@D)/cuda/include/curand_mtgp32_kernel.h" && cp -f "/usr/local/cuda-10.0/include/curand_mtgp32dc_p_11213.h" "$(@D)/cuda/include/curand_mtgp32dc_p_11213.h" && cp -f "/usr/local/cuda-10.0/include/curand_normal.h" "$(@D)/cuda/include/curand_normal.h" && cp -f "/usr/local/cuda-10.0/include/curand_normal_static.h" "$(@D)/cuda/include/curand_normal_static.h" && cp -f "/usr/local/cuda-10.0/include/curand_philox4x32_x.h" "$(@D)/cuda/include/curand_philox4x32_x.h" && cp -f "/usr/local/cuda-10.0/include/curand_poisson.h" "$(@D)/cuda/include/curand_poisson.h" && cp -f "/usr/local/cuda-10.0/include/curand_precalc.h" "$(@D)/cuda/include/curand_precalc.h" && cp -f "/usr/local/cuda-10.0/include/curand_uniform.h" "$(@D)/cuda/include/curand_uniform.h" && cp -f "/usr/local/cuda-10.0/include/cusolverDn.h" "$(@D)/cuda/include/cusolverDn.h" && cp -f "/usr/local/cuda-10.0/include/cusolverRf.h" "$(@D)/cuda/include/cusolverRf.h" && cp -f "/usr/local/cuda-10.0/include/cusolverSp.h" "$(@D)/cuda/include/cusolverSp.h" && cp -f "/usr/local/cuda-10.0/include/cusolverSp_LOWLEVEL_PREVIEW.h" "$(@D)/cuda/include/cusolverSp_LOWLEVEL_PREVIEW.h" && cp -f "/usr/local/cuda-10.0/include/cusolver_common.h" "$(@D)/cuda/include/cusolver_common.h" && cp -f "/usr/local/cuda-10.0/include/cusparse.h" "$(@D)/cuda/include/cusparse.h" && cp -f "/usr/local/cuda-10.0/include/cusparse_v2.h" "$(@D)/cuda/include/cusparse_v2.h" && cp -f "/usr/local/cuda-10.0/include/device_atomic_functions.h" "$(@D)/cuda/include/device_atomic_functions.h" && cp -f "/usr/local/cuda-10.0/include/device_atomic_functions.hpp" "$(@D)/cuda/include/device_atomic_functions.hpp" && cp -f "/usr/local/cuda-10.0/include/device_double_functions.h" "$(@D)/cuda/include/device_double_functions.h" && cp -f "/usr/local/cuda-10.0/include/device_functions.h" "$(@D)/cuda/include/device_functions.h" && cp -f "/usr/local/cuda-10.0/include/device_launch_parameters.h" "$(@D)/cuda/include/device_launch_parameters.h" && cp -f "/usr/local/cuda-10.0/include/device_types.h" "$(@D)/cuda/include/device_types.h" && cp -f "/usr/local/cuda-10.0/include/driver_functions.h" "$(@D)/cuda/include/driver_functions.h" && cp -f "/usr/local/cuda-10.0/include/driver_types.h" "$(@D)/cuda/include/driver_types.h" && cp -f "/usr/local/cuda-10.0/include/fatBinaryCtl.h" "$(@D)/cuda/include/fatBinaryCtl.h" && cp -f "/usr/local/cuda-10.0/include/fatbinary.h" "$(@D)/cuda/include/fatbinary.h" && cp -f "/usr/local/cuda-10.0/include/host_config.h" "$(@D)/cuda/include/host_config.h" && cp -f "/usr/local/cuda-10.0/include/host_defines.h" "$(@D)/cuda/include/host_defines.h" && cp -f "/usr/local/cuda-10.0/include/library_types.h" "$(@D)/cuda/include/library_types.h" && cp -f "/usr/local/cuda-10.0/include/math_constants.h" "$(@D)/cuda/include/math_constants.h" && cp -f "/usr/local/cuda-10.0/include/math_functions.h" "$(@D)/cuda/include/math_functions.h" && cp -f "/usr/local/cuda-10.0/include/mma.h" "$(@D)/cuda/include/mma.h" && cp -f "/usr/local/cuda-10.0/include/npp.h" "$(@D)/cuda/include/npp.h" && cp -f "/usr/local/cuda-10.0/include/nppcore.h" "$(@D)/cuda/include/nppcore.h" && cp -f "/usr/local/cuda-10.0/include/nppdefs.h" "$(@D)/cuda/include/nppdefs.h" && cp -f "/usr/local/cuda-10.0/include/nppi.h" "$(@D)/cuda/include/nppi.h" && cp -f "/usr/local/cuda-10.0/include/nppi_arithmetic_and_logical_operations.h" "$(@D)/cuda/include/nppi_arithmetic_and_logical_operations.h" && cp -f "/usr/local/cuda-10.0/include/nppi_color_conversion.h" "$(@D)/cuda/include/nppi_color_conversion.h" && cp -f "/usr/local/cuda-10.0/include/nppi_compression_functions.h" "$(@D)/cuda/include/nppi_compression_functions.h" && cp -f "/usr/local/cuda-10.0/include/nppi_computer_vision.h" "$(@D)/cuda/include/nppi_computer_vision.h" && cp -f "/usr/local/cuda-10.0/include/nppi_data_exchange_and_initialization.h" "$(@D)/cuda/include/nppi_data_exchange_and_initialization.h" && cp -f "/usr/local/cuda-10.0/include/nppi_filtering_functions.h" "$(@D)/cuda/include/nppi_filtering_functions.h" && cp -f "/usr/local/cuda-10.0/include/nppi_geometry_transforms.h" "$(@D)/cuda/include/nppi_geometry_transforms.h" && cp -f "/usr/local/cuda-10.0/include/nppi_linear_transforms.h" "$(@D)/cuda/include/nppi_linear_transforms.h" && cp -f "/usr/local/cuda-10.0/include/nppi_morphological_operations.h" "$(@D)/cuda/include/nppi_morphological_operations.h" && cp -f "/usr/local/cuda-10.0/include/nppi_statistics_functions.h" "$(@D)/cuda/include/nppi_statistics_functions.h" && cp -f "/usr/local/cuda-10.0/include/nppi_support_functions.h" "$(@D)/cuda/include/nppi_support_functions.h" && cp -f "/usr/local/cuda-10.0/include/nppi_threshold_and_compare_operations.h" "$(@D)/cuda/include/nppi_threshold_and_compare_operations.h" && cp -f "/usr/local/cuda-10.0/include/npps.h" "$(@D)/cuda/include/npps.h" && cp -f "/usr/local/cuda-10.0/include/npps_arithmetic_and_logical_operations.h" "$(@D)/cuda/include/npps_arithmetic_and_logical_operations.h" && cp -f "/usr/local/cuda-10.0/include/npps_conversion_functions.h" "$(@D)/cuda/include/npps_conversion_functions.h" && cp -f "/usr/local/cuda-10.0/include/npps_filtering_functions.h" "$(@D)/cuda/include/npps_filtering_functions.h" && cp -f "/usr/local/cuda-10.0/include/npps_initialization.h" "$(@D)/cuda/include/npps_initialization.h" && cp -f "/usr/local/cuda-10.0/include/npps_statistics_functions.h" "$(@D)/cuda/include/npps_statistics_functions.h" && cp -f "/usr/local/cuda-10.0/include/npps_support_functions.h" "$(@D)/cuda/include/npps_support_functions.h" && cp -f "/usr/local/cuda-10.0/include/nppversion.h" "$(@D)/cuda/include/nppversion.h" && cp -f "/usr/local/cuda-10.0/include/nvToolsExt.h" "$(@D)/cuda/include/nvToolsExt.h" && cp -f "/usr/local/cuda-10.0/include/nvToolsExtCuda.h" "$(@D)/cuda/include/nvToolsExtCuda.h" && cp -f "/usr/local/cuda-10.0/include/nvToolsExtCudaRt.h" "$(@D)/cuda/include/nvToolsExtCudaRt.h" && cp -f "/usr/local/cuda-10.0/include/nvToolsExtMeta.h" "$(@D)/cuda/include/nvToolsExtMeta.h" && cp -f "/usr/local/cuda-10.0/include/nvToolsExtSync.h" "$(@D)/cuda/include/nvToolsExtSync.h" && cp -f "/usr/local/cuda-10.0/include/nvblas.h" "$(@D)/cuda/include/nvblas.h" && cp -f "/usr/local/cuda-10.0/include/nvfunctional" "$(@D)/cuda/include/nvfunctional" && cp -f "/usr/local/cuda-10.0/include/nvgraph.h" "$(@D)/cuda/include/nvgraph.h" && cp -f "/usr/local/cuda-10.0/include/nvjpeg.h" "$(@D)/cuda/include/nvjpeg.h" && cp -f "/usr/local/cuda-10.0/include/nvml.h" "$(@D)/cuda/include/nvml.h" && cp -f "/usr/local/cuda-10.0/include/nvrtc.h" "$(@D)/cuda/include/nvrtc.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvToolsExt.h" "$(@D)/cuda/include/nvtx3/nvToolsExt.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvToolsExtCuda.h" "$(@D)/cuda/include/nvtx3/nvToolsExtCuda.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvToolsExtCudaRt.h" "$(@D)/cuda/include/nvtx3/nvToolsExtCudaRt.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvToolsExtOpenCL.h" "$(@D)/cuda/include/nvtx3/nvToolsExtOpenCL.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvToolsExtSync.h" "$(@D)/cuda/include/nvtx3/nvToolsExtSync.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvtxDetail/nvtxImpl.h" "$(@D)/cuda/include/nvtx3/nvtxDetail/nvtxImpl.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvtxDetail/nvtxImplCore.h" "$(@D)/cuda/include/nvtx3/nvtxDetail/nvtxImplCore.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h" "$(@D)/cuda/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h" "$(@D)/cuda/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h" "$(@D)/cuda/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h" "$(@D)/cuda/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvtxDetail/nvtxInit.h" "$(@D)/cuda/include/nvtx3/nvtxDetail/nvtxInit.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvtxDetail/nvtxInitDecls.h" "$(@D)/cuda/include/nvtx3/nvtxDetail/nvtxInitDecls.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvtxDetail/nvtxInitDefs.h" "$(@D)/cuda/include/nvtx3/nvtxDetail/nvtxInitDefs.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvtxDetail/nvtxLinkOnce.h" "$(@D)/cuda/include/nvtx3/nvtxDetail/nvtxLinkOnce.h" && cp -f "/usr/local/cuda-10.0/include/nvtx3/nvtxDetail/nvtxTypes.h" "$(@D)/cuda/include/nvtx3/nvtxDetail/nvtxTypes.h" && cp -f "/usr/local/cuda-10.0/include/sm_20_atomic_functions.h" "$(@D)/cuda/include/sm_20_atomic_functions.h" && cp -f "/usr/local/cuda-10.0/include/sm_20_atomic_functions.hpp" "$(@D)/cuda/include/sm_20_atomic_functions.hpp" && cp -f "/usr/local/cuda-10.0/include/sm_20_intrinsics.h" "$(@D)/cuda/include/sm_20_intrinsics.h" && cp -f "/usr/local/cuda-10.0/include/sm_20_intrinsics.hpp" "$(@D)/cuda/include/sm_20_intrinsics.hpp" && cp -f "/usr/local/cuda-10.0/include/sm_30_intrinsics.h" "$(@D)/cuda/include/sm_30_intrinsics.h" && cp -f "/usr/local/cuda-10.0/include/sm_30_intrinsics.hpp" "$(@D)/cuda/include/sm_30_intrinsics.hpp" && cp -f "/usr/local/cuda-10.0/include/sm_32_atomic_functions.h" "$(@D)/cuda/include/sm_32_atomic_functions.h" && cp -f "/usr/local/cuda-10.0/include/sm_32_atomic_functions.hpp" "$(@D)/cuda/include/sm_32_atomic_functions.hpp" && cp -f "/usr/local/cuda-10.0/include/sm_32_intrinsics.h" "$(@D)/cuda/include/sm_32_intrinsics.h" && cp -f "/usr/local/cuda-10.0/include/sm_32_intrinsics.hpp" "$(@D)/cuda/include/sm_32_intrinsics.hpp" && cp -f "/usr/local/cuda-10.0/include/sm_35_atomic_functions.h" "$(@D)/cuda/include/sm_35_atomic_functions.h" && cp -f "/usr/local/cuda-10.0/include/sm_35_intrinsics.h" "$(@D)/cuda/include/sm_35_intrinsics.h" && cp -f "/usr/local/cuda-10.0/include/sm_60_atomic_functions.h" "$(@D)/cuda/include/sm_60_atomic_functions.h" && cp -f "/usr/local/cuda-10.0/include/sm_60_atomic_functions.hpp" "$(@D)/cuda/include/sm_60_atomic_functions.hpp" && cp -f "/usr/local/cuda-10.0/include/sm_61_intrinsics.h" "$(@D)/cuda/include/sm_61_intrinsics.h" && cp -f "/usr/local/cuda-10.0/include/sm_61_intrinsics.hpp" "$(@D)/cuda/include/sm_61_intrinsics.hpp" && cp -f "/usr/local/cuda-10.0/include/sobol_direction_vectors.h" "$(@D)/cuda/include/sobol_direction_vectors.h" && cp -f "/usr/local/cuda-10.0/include/surface_functions.h" "$(@D)/cuda/include/surface_functions.h" && cp -f "/usr/local/cuda-10.0/include/surface_functions.hpp" "$(@D)/cuda/include/surface_functions.hpp" && cp -f "/usr/local/cuda-10.0/include/surface_indirect_functions.h" "$(@D)/cuda/include/surface_indirect_functions.h" && cp -f "/usr/local/cuda-10.0/include/surface_indirect_functions.hpp" "$(@D)/cuda/include/surface_indirect_functions.hpp" && cp -f "/usr/local/cuda-10.0/include/surface_types.h" "$(@D)/cuda/include/surface_types.h" && cp -f "/usr/local/cuda-10.0/include/texture_fetch_functions.h" "$(@D)/cuda/include/texture_fetch_functions.h" && cp -f "/usr/local/cuda-10.0/include/texture_fetch_functions.hpp" "$(@D)/cuda/include/texture_fetch_functions.hpp" && cp -f "/usr/local/cuda-10.0/include/texture_indirect_functions.h" "$(@D)/cuda/include/texture_indirect_functions.h" && cp -f "/usr/local/cuda-10.0/include/texture_indirect_functions.hpp" "$(@D)/cuda/include/texture_indirect_functions.hpp" && cp -f "/usr/local/cuda-10.0/include/texture_types.h" "$(@D)/cuda/include/texture_types.h" && cp -f "/usr/local/cuda-10.0/include/thrust/adjacent_difference.h" "$(@D)/cuda/include/thrust/adjacent_difference.h" && cp -f "/usr/local/cuda-10.0/include/thrust/advance.h" "$(@D)/cuda/include/thrust/advance.h" && cp -f "/usr/local/cuda-10.0/include/thrust/binary_search.h" "$(@D)/cuda/include/thrust/binary_search.h" && cp -f "/usr/local/cuda-10.0/include/thrust/complex.h" "$(@D)/cuda/include/thrust/complex.h" && cp -f "/usr/local/cuda-10.0/include/thrust/copy.h" "$(@D)/cuda/include/thrust/copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/count.h" "$(@D)/cuda/include/thrust/count.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/adjacent_difference.inl" "$(@D)/cuda/include/thrust/detail/adjacent_difference.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/advance.inl" "$(@D)/cuda/include/thrust/detail/advance.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/alignment.h" "$(@D)/cuda/include/thrust/detail/alignment.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/allocator_traits.h" "$(@D)/cuda/include/thrust/detail/allocator/allocator_traits.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/allocator_traits.inl" "$(@D)/cuda/include/thrust/detail/allocator/allocator_traits.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/copy_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/copy_construct_range.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/copy_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/copy_construct_range.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/default_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/default_construct_range.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/default_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/default_construct_range.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/destroy_range.h" "$(@D)/cuda/include/thrust/detail/allocator/destroy_range.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/destroy_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/destroy_range.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/fill_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/fill_construct_range.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/fill_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/fill_construct_range.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/malloc_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/malloc_allocator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/malloc_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/malloc_allocator.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/no_throw_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/no_throw_allocator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/tagged_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/tagged_allocator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/tagged_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/tagged_allocator.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/temporary_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/temporary_allocator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/allocator/temporary_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/temporary_allocator.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/binary_search.inl" "$(@D)/cuda/include/thrust/detail/binary_search.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/arithmetic.h" "$(@D)/cuda/include/thrust/detail/complex/arithmetic.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/c99math.h" "$(@D)/cuda/include/thrust/detail/complex/c99math.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/catrig.h" "$(@D)/cuda/include/thrust/detail/complex/catrig.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/catrigf.h" "$(@D)/cuda/include/thrust/detail/complex/catrigf.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/ccosh.h" "$(@D)/cuda/include/thrust/detail/complex/ccosh.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/ccoshf.h" "$(@D)/cuda/include/thrust/detail/complex/ccoshf.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/cexp.h" "$(@D)/cuda/include/thrust/detail/complex/cexp.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/cexpf.h" "$(@D)/cuda/include/thrust/detail/complex/cexpf.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/clog.h" "$(@D)/cuda/include/thrust/detail/complex/clog.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/clogf.h" "$(@D)/cuda/include/thrust/detail/complex/clogf.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/complex.inl" "$(@D)/cuda/include/thrust/detail/complex/complex.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/cpow.h" "$(@D)/cuda/include/thrust/detail/complex/cpow.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/cproj.h" "$(@D)/cuda/include/thrust/detail/complex/cproj.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/csinh.h" "$(@D)/cuda/include/thrust/detail/complex/csinh.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/csinhf.h" "$(@D)/cuda/include/thrust/detail/complex/csinhf.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/csqrt.h" "$(@D)/cuda/include/thrust/detail/complex/csqrt.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/csqrtf.h" "$(@D)/cuda/include/thrust/detail/complex/csqrtf.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/ctanh.h" "$(@D)/cuda/include/thrust/detail/complex/ctanh.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/ctanhf.h" "$(@D)/cuda/include/thrust/detail/complex/ctanhf.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/math_private.h" "$(@D)/cuda/include/thrust/detail/complex/math_private.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/complex/stream.h" "$(@D)/cuda/include/thrust/detail/complex/stream.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/config.h" "$(@D)/cuda/include/thrust/detail/config.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/config/compiler.h" "$(@D)/cuda/include/thrust/detail/config/compiler.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/config/compiler_fence.h" "$(@D)/cuda/include/thrust/detail/config/compiler_fence.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/config/config.h" "$(@D)/cuda/include/thrust/detail/config/config.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/config/debug.h" "$(@D)/cuda/include/thrust/detail/config/debug.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/config/device_system.h" "$(@D)/cuda/include/thrust/detail/config/device_system.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/config/exec_check_disable.h" "$(@D)/cuda/include/thrust/detail/config/exec_check_disable.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/config/forceinline.h" "$(@D)/cuda/include/thrust/detail/config/forceinline.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/config/global_workarounds.h" "$(@D)/cuda/include/thrust/detail/config/global_workarounds.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/config/host_device.h" "$(@D)/cuda/include/thrust/detail/config/host_device.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/config/host_system.h" "$(@D)/cuda/include/thrust/detail/config/host_system.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/config/simple_defines.h" "$(@D)/cuda/include/thrust/detail/config/simple_defines.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/contiguous_storage.h" "$(@D)/cuda/include/thrust/detail/contiguous_storage.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/contiguous_storage.inl" "$(@D)/cuda/include/thrust/detail/contiguous_storage.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/copy.h" "$(@D)/cuda/include/thrust/detail/copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/copy.inl" "$(@D)/cuda/include/thrust/detail/copy.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/copy_if.h" "$(@D)/cuda/include/thrust/detail/copy_if.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/copy_if.inl" "$(@D)/cuda/include/thrust/detail/copy_if.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/count.inl" "$(@D)/cuda/include/thrust/detail/count.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/cstdint.h" "$(@D)/cuda/include/thrust/detail/cstdint.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/device_delete.inl" "$(@D)/cuda/include/thrust/detail/device_delete.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/device_free.inl" "$(@D)/cuda/include/thrust/detail/device_free.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/device_malloc.inl" "$(@D)/cuda/include/thrust/detail/device_malloc.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/device_new.inl" "$(@D)/cuda/include/thrust/detail/device_new.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/device_ptr.inl" "$(@D)/cuda/include/thrust/detail/device_ptr.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/device_reference.inl" "$(@D)/cuda/include/thrust/detail/device_reference.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/device_vector.inl" "$(@D)/cuda/include/thrust/detail/device_vector.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/dispatch/is_trivial_copy.h" "$(@D)/cuda/include/thrust/detail/dispatch/is_trivial_copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/distance.inl" "$(@D)/cuda/include/thrust/detail/distance.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/equal.inl" "$(@D)/cuda/include/thrust/detail/equal.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/execute_with_allocator.h" "$(@D)/cuda/include/thrust/detail/execute_with_allocator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/execution_policy.h" "$(@D)/cuda/include/thrust/detail/execution_policy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/extrema.inl" "$(@D)/cuda/include/thrust/detail/extrema.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/fill.inl" "$(@D)/cuda/include/thrust/detail/fill.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/find.inl" "$(@D)/cuda/include/thrust/detail/find.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/for_each.inl" "$(@D)/cuda/include/thrust/detail/for_each.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/function.h" "$(@D)/cuda/include/thrust/detail/function.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional.inl" "$(@D)/cuda/include/thrust/detail/functional.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/actor.h" "$(@D)/cuda/include/thrust/detail/functional/actor.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/actor.inl" "$(@D)/cuda/include/thrust/detail/functional/actor.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/argument.h" "$(@D)/cuda/include/thrust/detail/functional/argument.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/composite.h" "$(@D)/cuda/include/thrust/detail/functional/composite.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/operators/arithmetic_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/arithmetic_operators.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/operators/assignment_operator.h" "$(@D)/cuda/include/thrust/detail/functional/operators/assignment_operator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/operators/bitwise_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/bitwise_operators.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/operators/compound_assignment_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/compound_assignment_operators.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/operators/logical_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/logical_operators.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/operators/operator_adaptors.h" "$(@D)/cuda/include/thrust/detail/functional/operators/operator_adaptors.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/operators/relational_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/relational_operators.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/placeholder.h" "$(@D)/cuda/include/thrust/detail/functional/placeholder.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/functional/value.h" "$(@D)/cuda/include/thrust/detail/functional/value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/gather.inl" "$(@D)/cuda/include/thrust/detail/gather.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/generate.inl" "$(@D)/cuda/include/thrust/detail/generate.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/get_iterator_value.h" "$(@D)/cuda/include/thrust/detail/get_iterator_value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/host_vector.inl" "$(@D)/cuda/include/thrust/detail/host_vector.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/inner_product.inl" "$(@D)/cuda/include/thrust/detail/inner_product.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/integer_math.h" "$(@D)/cuda/include/thrust/detail/integer_math.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/integer_traits.h" "$(@D)/cuda/include/thrust/detail/integer_traits.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/internal_functional.h" "$(@D)/cuda/include/thrust/detail/internal_functional.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/logical.inl" "$(@D)/cuda/include/thrust/detail/logical.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/detail/malloc_and_free.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/merge.inl" "$(@D)/cuda/include/thrust/detail/merge.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/minmax.h" "$(@D)/cuda/include/thrust/detail/minmax.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/mismatch.inl" "$(@D)/cuda/include/thrust/detail/mismatch.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/mpl/math.h" "$(@D)/cuda/include/thrust/detail/mpl/math.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/numeric_traits.h" "$(@D)/cuda/include/thrust/detail/numeric_traits.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/overlapped_copy.h" "$(@D)/cuda/include/thrust/detail/overlapped_copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/pair.inl" "$(@D)/cuda/include/thrust/detail/pair.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/partition.inl" "$(@D)/cuda/include/thrust/detail/partition.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/pointer.h" "$(@D)/cuda/include/thrust/detail/pointer.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/pointer.inl" "$(@D)/cuda/include/thrust/detail/pointer.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/preprocessor.h" "$(@D)/cuda/include/thrust/detail/preprocessor.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/range/head_flags.h" "$(@D)/cuda/include/thrust/detail/range/head_flags.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/range/tail_flags.h" "$(@D)/cuda/include/thrust/detail/range/tail_flags.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/raw_pointer_cast.h" "$(@D)/cuda/include/thrust/detail/raw_pointer_cast.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/raw_reference_cast.h" "$(@D)/cuda/include/thrust/detail/raw_reference_cast.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/reduce.inl" "$(@D)/cuda/include/thrust/detail/reduce.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/reference.h" "$(@D)/cuda/include/thrust/detail/reference.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/reference.inl" "$(@D)/cuda/include/thrust/detail/reference.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/reference_forward_declaration.h" "$(@D)/cuda/include/thrust/detail/reference_forward_declaration.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/remove.inl" "$(@D)/cuda/include/thrust/detail/remove.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/replace.inl" "$(@D)/cuda/include/thrust/detail/replace.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/reverse.inl" "$(@D)/cuda/include/thrust/detail/reverse.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/scan.inl" "$(@D)/cuda/include/thrust/detail/scan.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/scatter.inl" "$(@D)/cuda/include/thrust/detail/scatter.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/seq.h" "$(@D)/cuda/include/thrust/detail/seq.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/sequence.inl" "$(@D)/cuda/include/thrust/detail/sequence.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/set_operations.inl" "$(@D)/cuda/include/thrust/detail/set_operations.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/sort.inl" "$(@D)/cuda/include/thrust/detail/sort.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/static_assert.h" "$(@D)/cuda/include/thrust/detail/static_assert.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/static_map.h" "$(@D)/cuda/include/thrust/detail/static_map.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/swap.h" "$(@D)/cuda/include/thrust/detail/swap.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/swap.inl" "$(@D)/cuda/include/thrust/detail/swap.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/swap_ranges.inl" "$(@D)/cuda/include/thrust/detail/swap_ranges.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/tabulate.inl" "$(@D)/cuda/include/thrust/detail/tabulate.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/temporary_array.h" "$(@D)/cuda/include/thrust/detail/temporary_array.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/temporary_array.inl" "$(@D)/cuda/include/thrust/detail/temporary_array.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/detail/temporary_buffer.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/transform.inl" "$(@D)/cuda/include/thrust/detail/transform.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/transform_reduce.inl" "$(@D)/cuda/include/thrust/detail/transform_reduce.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/transform_scan.inl" "$(@D)/cuda/include/thrust/detail/transform_scan.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/trivial_sequence.h" "$(@D)/cuda/include/thrust/detail/trivial_sequence.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/tuple.inl" "$(@D)/cuda/include/thrust/detail/tuple.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/tuple_meta_transform.h" "$(@D)/cuda/include/thrust/detail/tuple_meta_transform.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/tuple_transform.h" "$(@D)/cuda/include/thrust/detail/tuple_transform.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/type_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h" "$(@D)/cuda/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/type_traits/function_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits/function_traits.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/type_traits/has_member_function.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_member_function.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/type_traits/has_nested_type.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_nested_type.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/type_traits/has_trivial_assign.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_trivial_assign.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/type_traits/is_call_possible.h" "$(@D)/cuda/include/thrust/detail/type_traits/is_call_possible.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/type_traits/is_metafunction_defined.h" "$(@D)/cuda/include/thrust/detail/type_traits/is_metafunction_defined.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/type_traits/iterator/is_discard_iterator.h" "$(@D)/cuda/include/thrust/detail/type_traits/iterator/is_discard_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/type_traits/iterator/is_output_iterator.h" "$(@D)/cuda/include/thrust/detail/type_traits/iterator/is_output_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/type_traits/minimum_type.h" "$(@D)/cuda/include/thrust/detail/type_traits/minimum_type.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/type_traits/pointer_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits/pointer_traits.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/type_traits/result_of_adaptable_function.h" "$(@D)/cuda/include/thrust/detail/type_traits/result_of_adaptable_function.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/uninitialized_copy.inl" "$(@D)/cuda/include/thrust/detail/uninitialized_copy.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/uninitialized_fill.inl" "$(@D)/cuda/include/thrust/detail/uninitialized_fill.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/unique.inl" "$(@D)/cuda/include/thrust/detail/unique.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/use_default.h" "$(@D)/cuda/include/thrust/detail/use_default.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/util/align.h" "$(@D)/cuda/include/thrust/detail/util/align.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/util/blocking.h" "$(@D)/cuda/include/thrust/detail/util/blocking.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/vector_base.h" "$(@D)/cuda/include/thrust/detail/vector_base.h" && cp -f "/usr/local/cuda-10.0/include/thrust/detail/vector_base.inl" "$(@D)/cuda/include/thrust/detail/vector_base.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/device_allocator.h" "$(@D)/cuda/include/thrust/device_allocator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/device_delete.h" "$(@D)/cuda/include/thrust/device_delete.h" && cp -f "/usr/local/cuda-10.0/include/thrust/device_free.h" "$(@D)/cuda/include/thrust/device_free.h" && cp -f "/usr/local/cuda-10.0/include/thrust/device_malloc.h" "$(@D)/cuda/include/thrust/device_malloc.h" && cp -f "/usr/local/cuda-10.0/include/thrust/device_malloc_allocator.h" "$(@D)/cuda/include/thrust/device_malloc_allocator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/device_new.h" "$(@D)/cuda/include/thrust/device_new.h" && cp -f "/usr/local/cuda-10.0/include/thrust/device_new_allocator.h" "$(@D)/cuda/include/thrust/device_new_allocator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/device_ptr.h" "$(@D)/cuda/include/thrust/device_ptr.h" && cp -f "/usr/local/cuda-10.0/include/thrust/device_reference.h" "$(@D)/cuda/include/thrust/device_reference.h" && cp -f "/usr/local/cuda-10.0/include/thrust/device_vector.h" "$(@D)/cuda/include/thrust/device_vector.h" && cp -f "/usr/local/cuda-10.0/include/thrust/distance.h" "$(@D)/cuda/include/thrust/distance.h" && cp -f "/usr/local/cuda-10.0/include/thrust/equal.h" "$(@D)/cuda/include/thrust/equal.h" && cp -f "/usr/local/cuda-10.0/include/thrust/execution_policy.h" "$(@D)/cuda/include/thrust/execution_policy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/extrema.h" "$(@D)/cuda/include/thrust/extrema.h" && cp -f "/usr/local/cuda-10.0/include/thrust/fill.h" "$(@D)/cuda/include/thrust/fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/find.h" "$(@D)/cuda/include/thrust/find.h" && cp -f "/usr/local/cuda-10.0/include/thrust/for_each.h" "$(@D)/cuda/include/thrust/for_each.h" && cp -f "/usr/local/cuda-10.0/include/thrust/functional.h" "$(@D)/cuda/include/thrust/functional.h" && cp -f "/usr/local/cuda-10.0/include/thrust/gather.h" "$(@D)/cuda/include/thrust/gather.h" && cp -f "/usr/local/cuda-10.0/include/thrust/generate.h" "$(@D)/cuda/include/thrust/generate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/host_vector.h" "$(@D)/cuda/include/thrust/host_vector.h" && cp -f "/usr/local/cuda-10.0/include/thrust/inner_product.h" "$(@D)/cuda/include/thrust/inner_product.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/constant_iterator.h" "$(@D)/cuda/include/thrust/iterator/constant_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/counting_iterator.h" "$(@D)/cuda/include/thrust/iterator/counting_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/any_assign.h" "$(@D)/cuda/include/thrust/iterator/detail/any_assign.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/any_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/any_system_tag.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/constant_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/constant_iterator_base.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/counting_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/counting_iterator.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/device_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/device_system_tag.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/discard_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/discard_iterator_base.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/distance_from_result.h" "$(@D)/cuda/include/thrust/iterator/detail/distance_from_result.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/host_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/host_system_tag.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/is_iterator_category.h" "$(@D)/cuda/include/thrust/iterator/detail/is_iterator_category.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/is_trivial_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/is_trivial_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/iterator_adaptor_base.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_adaptor_base.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/iterator_category_to_system.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_to_system.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/iterator_category_to_traversal.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_to_traversal.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/iterator_facade_category.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_facade_category.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/iterator_traits.inl" "$(@D)/cuda/include/thrust/iterator/detail/iterator_traits.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/iterator_traversal_tags.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_traversal_tags.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/join_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/join_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/minimum_category.h" "$(@D)/cuda/include/thrust/iterator/detail/minimum_category.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/minimum_system.h" "$(@D)/cuda/include/thrust/iterator/detail/minimum_system.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/normal_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/normal_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/permutation_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/permutation_iterator_base.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/retag.h" "$(@D)/cuda/include/thrust/iterator/detail/retag.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/reverse_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/reverse_iterator.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/reverse_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/reverse_iterator_base.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/tagged_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/tagged_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/transform_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/transform_iterator.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/transform_output_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/transform_output_iterator.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/tuple_of_iterator_references.h" "$(@D)/cuda/include/thrust/iterator/detail/tuple_of_iterator_references.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/universal_categories.h" "$(@D)/cuda/include/thrust/iterator/detail/universal_categories.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/zip_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/zip_iterator.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/detail/zip_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/zip_iterator_base.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/discard_iterator.h" "$(@D)/cuda/include/thrust/iterator/discard_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/iterator_adaptor.h" "$(@D)/cuda/include/thrust/iterator/iterator_adaptor.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/iterator_categories.h" "$(@D)/cuda/include/thrust/iterator/iterator_categories.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/iterator_facade.h" "$(@D)/cuda/include/thrust/iterator/iterator_facade.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/iterator_traits.h" "$(@D)/cuda/include/thrust/iterator/iterator_traits.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/permutation_iterator.h" "$(@D)/cuda/include/thrust/iterator/permutation_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/retag.h" "$(@D)/cuda/include/thrust/iterator/retag.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/reverse_iterator.h" "$(@D)/cuda/include/thrust/iterator/reverse_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/transform_iterator.h" "$(@D)/cuda/include/thrust/iterator/transform_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/transform_output_iterator.h" "$(@D)/cuda/include/thrust/iterator/transform_output_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/iterator/zip_iterator.h" "$(@D)/cuda/include/thrust/iterator/zip_iterator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/logical.h" "$(@D)/cuda/include/thrust/logical.h" && cp -f "/usr/local/cuda-10.0/include/thrust/memory.h" "$(@D)/cuda/include/thrust/memory.h" && cp -f "/usr/local/cuda-10.0/include/thrust/merge.h" "$(@D)/cuda/include/thrust/merge.h" && cp -f "/usr/local/cuda-10.0/include/thrust/mismatch.h" "$(@D)/cuda/include/thrust/mismatch.h" && cp -f "/usr/local/cuda-10.0/include/thrust/pair.h" "$(@D)/cuda/include/thrust/pair.h" && cp -f "/usr/local/cuda-10.0/include/thrust/partition.h" "$(@D)/cuda/include/thrust/partition.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random.h" "$(@D)/cuda/include/thrust/random.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/discard_block_engine.inl" "$(@D)/cuda/include/thrust/random/detail/discard_block_engine.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/linear_congruential_engine.inl" "$(@D)/cuda/include/thrust/random/detail/linear_congruential_engine.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/linear_congruential_engine_discard.h" "$(@D)/cuda/include/thrust/random/detail/linear_congruential_engine_discard.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/linear_feedback_shift_engine.inl" "$(@D)/cuda/include/thrust/random/detail/linear_feedback_shift_engine.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h" "$(@D)/cuda/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/mod.h" "$(@D)/cuda/include/thrust/random/detail/mod.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/normal_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/normal_distribution.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/normal_distribution_base.h" "$(@D)/cuda/include/thrust/random/detail/normal_distribution_base.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/random_core_access.h" "$(@D)/cuda/include/thrust/random/detail/random_core_access.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/subtract_with_carry_engine.inl" "$(@D)/cuda/include/thrust/random/detail/subtract_with_carry_engine.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/uniform_int_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/uniform_int_distribution.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/uniform_real_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/uniform_real_distribution.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/xor_combine_engine.inl" "$(@D)/cuda/include/thrust/random/detail/xor_combine_engine.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/random/detail/xor_combine_engine_max.h" "$(@D)/cuda/include/thrust/random/detail/xor_combine_engine_max.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/discard_block_engine.h" "$(@D)/cuda/include/thrust/random/discard_block_engine.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/linear_congruential_engine.h" "$(@D)/cuda/include/thrust/random/linear_congruential_engine.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/linear_feedback_shift_engine.h" "$(@D)/cuda/include/thrust/random/linear_feedback_shift_engine.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/normal_distribution.h" "$(@D)/cuda/include/thrust/random/normal_distribution.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/subtract_with_carry_engine.h" "$(@D)/cuda/include/thrust/random/subtract_with_carry_engine.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/uniform_int_distribution.h" "$(@D)/cuda/include/thrust/random/uniform_int_distribution.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/uniform_real_distribution.h" "$(@D)/cuda/include/thrust/random/uniform_real_distribution.h" && cp -f "/usr/local/cuda-10.0/include/thrust/random/xor_combine_engine.h" "$(@D)/cuda/include/thrust/random/xor_combine_engine.h" && cp -f "/usr/local/cuda-10.0/include/thrust/reduce.h" "$(@D)/cuda/include/thrust/reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/remove.h" "$(@D)/cuda/include/thrust/remove.h" && cp -f "/usr/local/cuda-10.0/include/thrust/replace.h" "$(@D)/cuda/include/thrust/replace.h" && cp -f "/usr/local/cuda-10.0/include/thrust/reverse.h" "$(@D)/cuda/include/thrust/reverse.h" && cp -f "/usr/local/cuda-10.0/include/thrust/scan.h" "$(@D)/cuda/include/thrust/scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/scatter.h" "$(@D)/cuda/include/thrust/scatter.h" && cp -f "/usr/local/cuda-10.0/include/thrust/sequence.h" "$(@D)/cuda/include/thrust/sequence.h" && cp -f "/usr/local/cuda-10.0/include/thrust/set_operations.h" "$(@D)/cuda/include/thrust/set_operations.h" && cp -f "/usr/local/cuda-10.0/include/thrust/sort.h" "$(@D)/cuda/include/thrust/sort.h" && cp -f "/usr/local/cuda-10.0/include/thrust/swap.h" "$(@D)/cuda/include/thrust/swap.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/cpp/detail/adjacent_difference.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/cpp/detail/assign_value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/cpp/detail/binary_search.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/copy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/cpp/detail/copy_if.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/count.h" "$(@D)/cuda/include/thrust/system/cpp/detail/count.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/equal.h" "$(@D)/cuda/include/thrust/system/cpp/detail/equal.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/execution_policy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/extrema.h" "$(@D)/cuda/include/thrust/system/cpp/detail/extrema.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/fill.h" "$(@D)/cuda/include/thrust/system/cpp/detail/fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/find.h" "$(@D)/cuda/include/thrust/system/cpp/detail/find.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/for_each.h" "$(@D)/cuda/include/thrust/system/cpp/detail/for_each.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/gather.h" "$(@D)/cuda/include/thrust/system/cpp/detail/gather.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/generate.h" "$(@D)/cuda/include/thrust/system/cpp/detail/generate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/get_value.h" "$(@D)/cuda/include/thrust/system/cpp/detail/get_value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/cpp/detail/inner_product.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/cpp/detail/iter_swap.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/logical.h" "$(@D)/cuda/include/thrust/system/cpp/detail/logical.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/cpp/detail/malloc_and_free.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/memory.inl" "$(@D)/cuda/include/thrust/system/cpp/detail/memory.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/merge.h" "$(@D)/cuda/include/thrust/system/cpp/detail/merge.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/cpp/detail/mismatch.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/par.h" "$(@D)/cuda/include/thrust/system/cpp/detail/par.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/partition.h" "$(@D)/cuda/include/thrust/system/cpp/detail/partition.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/reduce.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reduce_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/remove.h" "$(@D)/cuda/include/thrust/system/cpp/detail/remove.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/replace.h" "$(@D)/cuda/include/thrust/system/cpp/detail/replace.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/reverse.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reverse.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/scan.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scan_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/scatter.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scatter.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/sequence.h" "$(@D)/cuda/include/thrust/system/cpp/detail/sequence.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/cpp/detail/set_operations.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/sort.h" "$(@D)/cuda/include/thrust/system/cpp/detail/sort.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/cpp/detail/swap_ranges.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/cpp/detail/tabulate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/cpp/detail/temporary_buffer.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/transform.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform_reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform_scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/uninitialized_copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/cpp/detail/uninitialized_fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/unique.h" "$(@D)/cuda/include/thrust/system/cpp/detail/unique.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/unique_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/detail/vector.inl" "$(@D)/cuda/include/thrust/system/cpp/detail/vector.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/execution_policy.h" "$(@D)/cuda/include/thrust/system/cpp/execution_policy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/memory.h" "$(@D)/cuda/include/thrust/system/cpp/memory.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cpp/vector.h" "$(@D)/cuda/include/thrust/system/cpp/vector.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/config.h" "$(@D)/cuda/include/thrust/system/cuda/config.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/cuda/detail/adjacent_difference.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/cuda/detail/assign_value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/cuda/detail/binary_search.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_if.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/core/agent_launcher.h" "$(@D)/cuda/include/thrust/system/cuda/detail/core/agent_launcher.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/core/alignment.h" "$(@D)/cuda/include/thrust/system/cuda/detail/core/alignment.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/core/triple_chevron_launch.h" "$(@D)/cuda/include/thrust/system/cuda/detail/core/triple_chevron_launch.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/core/util.h" "$(@D)/cuda/include/thrust/system/cuda/detail/core/util.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/count.h" "$(@D)/cuda/include/thrust/system/cuda/detail/count.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cross_system.h" "$(@D)/cuda/include/thrust/system/cuda/detail/cross_system.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/agent/agent_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_histogram.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_downsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_downsweep.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_upsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_radix_sort_upsweep.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/agent/agent_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_reduce.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/agent/agent_reduce_by_key.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_reduce_by_key.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/agent/agent_rle.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_rle.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/agent/agent_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_scan.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/agent/agent_segment_fixup.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_segment_fixup.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/agent/agent_select_if.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_select_if.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/agent/agent_spmv_orig.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/agent_spmv_orig.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/agent/single_pass_scan_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/agent/single_pass_scan_operators.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/block_adjacent_difference.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_adjacent_difference.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/block_load.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_load.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/block_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_scan.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/block_shuffle.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_shuffle.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/block_store.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_store.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans2.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans2.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans3.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans3.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/cub.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/cub.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/device_partition.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_partition.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/device_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_scan.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/device_segmented_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_segmented_radix_sort.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/device_segmented_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_segmented_reduce.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/device_select.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_select.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/device_spmv.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_spmv.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_histogram.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_radix_sort.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce_by_key.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_reduce_by_key.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_rle.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_rle.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_scan.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_select_if.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_select_if.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_orig.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/dispatch_spmv_orig.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/host/mutex.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/host/mutex.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/iterator/discard_output_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/discard_output_iterator.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/thread/thread_search.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_search.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/util_allocator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_allocator.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/util_arch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_arch.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/util_debug.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_debug.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/util_device.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_device.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/util_macro.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_macro.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/util_namespace.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_namespace.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/util_ptx.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_ptx.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/util_type.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_type.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/equal.h" "$(@D)/cuda/include/thrust/system/cuda/detail/equal.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/error.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/error.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/execution_policy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/extrema.h" "$(@D)/cuda/include/thrust/system/cuda/detail/extrema.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/fill.h" "$(@D)/cuda/include/thrust/system/cuda/detail/fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/find.h" "$(@D)/cuda/include/thrust/system/cuda/detail/find.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/for_each.h" "$(@D)/cuda/include/thrust/system/cuda/detail/for_each.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/gather.h" "$(@D)/cuda/include/thrust/system/cuda/detail/gather.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/generate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/generate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/get_value.h" "$(@D)/cuda/include/thrust/system/cuda/detail/get_value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h" "$(@D)/cuda/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/guarded_driver_types.h" "$(@D)/cuda/include/thrust/system/cuda/detail/guarded_driver_types.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/cuda/detail/inner_product.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/internal/copy_cross_system.h" "$(@D)/cuda/include/thrust/system/cuda/detail/internal/copy_cross_system.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/internal/copy_device_to_device.h" "$(@D)/cuda/include/thrust/system/cuda/detail/internal/copy_device_to_device.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/cuda/detail/iter_swap.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/logical.h" "$(@D)/cuda/include/thrust/system/cuda/detail/logical.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/cuda/detail/malloc_and_free.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/memory.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/memory.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/merge.h" "$(@D)/cuda/include/thrust/system/cuda/detail/merge.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/cuda/detail/mismatch.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/par.h" "$(@D)/cuda/include/thrust/system/cuda/detail/par.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/par_to_seq.h" "$(@D)/cuda/include/thrust/system/cuda/detail/par_to_seq.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/parallel_for.h" "$(@D)/cuda/include/thrust/system/cuda/detail/parallel_for.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/partition.h" "$(@D)/cuda/include/thrust/system/cuda/detail/partition.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/remove.h" "$(@D)/cuda/include/thrust/system/cuda/detail/remove.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/replace.h" "$(@D)/cuda/include/thrust/system/cuda/detail/replace.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/reverse.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reverse.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scan_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/scatter.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scatter.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/sequence.h" "$(@D)/cuda/include/thrust/system/cuda/detail/sequence.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/cuda/detail/set_operations.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/sort.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/cuda/detail/swap_ranges.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/tabulate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/cuda/detail/temporary_buffer.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/terminate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/terminate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/transform.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform_reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform_scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/uninitialized_copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/cuda/detail/uninitialized_fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/unique.h" "$(@D)/cuda/include/thrust/system/cuda/detail/unique.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/unique_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/util.h" "$(@D)/cuda/include/thrust/system/cuda/detail/util.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/detail/vector.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/vector.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/error.h" "$(@D)/cuda/include/thrust/system/cuda/error.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/execution_policy.h" "$(@D)/cuda/include/thrust/system/cuda/execution_policy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/experimental/pinned_allocator.h" "$(@D)/cuda/include/thrust/system/cuda/experimental/pinned_allocator.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/memory.h" "$(@D)/cuda/include/thrust/system/cuda/memory.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/cuda/vector.h" "$(@D)/cuda/include/thrust/system/cuda/vector.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/adl/adjacent_difference.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/assign_value.h" "$(@D)/cuda/include/thrust/system/detail/adl/assign_value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/adl/binary_search.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/copy.h" "$(@D)/cuda/include/thrust/system/detail/adl/copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/adl/copy_if.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/count.h" "$(@D)/cuda/include/thrust/system/detail/adl/count.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/equal.h" "$(@D)/cuda/include/thrust/system/detail/adl/equal.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/extrema.h" "$(@D)/cuda/include/thrust/system/detail/adl/extrema.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/fill.h" "$(@D)/cuda/include/thrust/system/detail/adl/fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/find.h" "$(@D)/cuda/include/thrust/system/detail/adl/find.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/for_each.h" "$(@D)/cuda/include/thrust/system/detail/adl/for_each.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/gather.h" "$(@D)/cuda/include/thrust/system/detail/adl/gather.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/generate.h" "$(@D)/cuda/include/thrust/system/detail/adl/generate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/get_value.h" "$(@D)/cuda/include/thrust/system/detail/adl/get_value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/adl/inner_product.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/iter_swap.h" "$(@D)/cuda/include/thrust/system/detail/adl/iter_swap.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/logical.h" "$(@D)/cuda/include/thrust/system/detail/adl/logical.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/detail/adl/malloc_and_free.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/merge.h" "$(@D)/cuda/include/thrust/system/detail/adl/merge.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/adl/mismatch.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/partition.h" "$(@D)/cuda/include/thrust/system/detail/adl/partition.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/reduce.h" "$(@D)/cuda/include/thrust/system/detail/adl/reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/reduce_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/remove.h" "$(@D)/cuda/include/thrust/system/detail/adl/remove.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/replace.h" "$(@D)/cuda/include/thrust/system/detail/adl/replace.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/reverse.h" "$(@D)/cuda/include/thrust/system/detail/adl/reverse.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/scan.h" "$(@D)/cuda/include/thrust/system/detail/adl/scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/scan_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/scatter.h" "$(@D)/cuda/include/thrust/system/detail/adl/scatter.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/sequence.h" "$(@D)/cuda/include/thrust/system/detail/adl/sequence.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/adl/set_operations.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/sort.h" "$(@D)/cuda/include/thrust/system/detail/adl/sort.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/adl/swap_ranges.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/adl/tabulate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/adl/temporary_buffer.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/transform.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform_reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform_scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/adl/uninitialized_copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/adl/uninitialized_fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/unique.h" "$(@D)/cuda/include/thrust/system/detail/adl/unique.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/adl/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/unique_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/bad_alloc.h" "$(@D)/cuda/include/thrust/system/detail/bad_alloc.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/errno.h" "$(@D)/cuda/include/thrust/system/detail/errno.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/error_category.inl" "$(@D)/cuda/include/thrust/system/detail/error_category.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/error_code.inl" "$(@D)/cuda/include/thrust/system/detail/error_code.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/error_condition.inl" "$(@D)/cuda/include/thrust/system/detail/error_condition.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/generic/adjacent_difference.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/adjacent_difference.inl" "$(@D)/cuda/include/thrust/system/detail/generic/adjacent_difference.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/advance.h" "$(@D)/cuda/include/thrust/system/detail/generic/advance.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/advance.inl" "$(@D)/cuda/include/thrust/system/detail/generic/advance.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/generic/binary_search.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/binary_search.inl" "$(@D)/cuda/include/thrust/system/detail/generic/binary_search.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/copy.h" "$(@D)/cuda/include/thrust/system/detail/generic/copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/copy.inl" "$(@D)/cuda/include/thrust/system/detail/generic/copy.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/generic/copy_if.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/copy_if.inl" "$(@D)/cuda/include/thrust/system/detail/generic/copy_if.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/count.h" "$(@D)/cuda/include/thrust/system/detail/generic/count.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/count.inl" "$(@D)/cuda/include/thrust/system/detail/generic/count.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/distance.h" "$(@D)/cuda/include/thrust/system/detail/generic/distance.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/distance.inl" "$(@D)/cuda/include/thrust/system/detail/generic/distance.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/equal.h" "$(@D)/cuda/include/thrust/system/detail/generic/equal.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/equal.inl" "$(@D)/cuda/include/thrust/system/detail/generic/equal.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/extrema.h" "$(@D)/cuda/include/thrust/system/detail/generic/extrema.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/extrema.inl" "$(@D)/cuda/include/thrust/system/detail/generic/extrema.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/fill.h" "$(@D)/cuda/include/thrust/system/detail/generic/fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/find.h" "$(@D)/cuda/include/thrust/system/detail/generic/find.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/find.inl" "$(@D)/cuda/include/thrust/system/detail/generic/find.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/for_each.h" "$(@D)/cuda/include/thrust/system/detail/generic/for_each.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/gather.h" "$(@D)/cuda/include/thrust/system/detail/generic/gather.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/gather.inl" "$(@D)/cuda/include/thrust/system/detail/generic/gather.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/generate.h" "$(@D)/cuda/include/thrust/system/detail/generic/generate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/generate.inl" "$(@D)/cuda/include/thrust/system/detail/generic/generate.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/generic/inner_product.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/inner_product.inl" "$(@D)/cuda/include/thrust/system/detail/generic/inner_product.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/logical.h" "$(@D)/cuda/include/thrust/system/detail/generic/logical.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/memory.h" "$(@D)/cuda/include/thrust/system/detail/generic/memory.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/memory.inl" "$(@D)/cuda/include/thrust/system/detail/generic/memory.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/merge.h" "$(@D)/cuda/include/thrust/system/detail/generic/merge.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/merge.inl" "$(@D)/cuda/include/thrust/system/detail/generic/merge.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/generic/mismatch.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/mismatch.inl" "$(@D)/cuda/include/thrust/system/detail/generic/mismatch.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/partition.h" "$(@D)/cuda/include/thrust/system/detail/generic/partition.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/partition.inl" "$(@D)/cuda/include/thrust/system/detail/generic/partition.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/reduce.h" "$(@D)/cuda/include/thrust/system/detail/generic/reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/reduce.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reduce.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/reduce_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reduce_by_key.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/remove.h" "$(@D)/cuda/include/thrust/system/detail/generic/remove.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/remove.inl" "$(@D)/cuda/include/thrust/system/detail/generic/remove.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/replace.h" "$(@D)/cuda/include/thrust/system/detail/generic/replace.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/replace.inl" "$(@D)/cuda/include/thrust/system/detail/generic/replace.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/reverse.h" "$(@D)/cuda/include/thrust/system/detail/generic/reverse.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/reverse.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reverse.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/scalar/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/generic/scalar/binary_search.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/scalar/binary_search.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scalar/binary_search.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/scan.h" "$(@D)/cuda/include/thrust/system/detail/generic/scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/scan.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scan.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/scan_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/scan_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scan_by_key.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/scatter.h" "$(@D)/cuda/include/thrust/system/detail/generic/scatter.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/scatter.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scatter.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/select_system.h" "$(@D)/cuda/include/thrust/system/detail/generic/select_system.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/sequence.h" "$(@D)/cuda/include/thrust/system/detail/generic/sequence.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/sequence.inl" "$(@D)/cuda/include/thrust/system/detail/generic/sequence.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/generic/set_operations.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/set_operations.inl" "$(@D)/cuda/include/thrust/system/detail/generic/set_operations.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/sort.h" "$(@D)/cuda/include/thrust/system/detail/generic/sort.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/sort.inl" "$(@D)/cuda/include/thrust/system/detail/generic/sort.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/generic/swap_ranges.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/swap_ranges.inl" "$(@D)/cuda/include/thrust/system/detail/generic/swap_ranges.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/generic/tabulate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/tabulate.inl" "$(@D)/cuda/include/thrust/system/detail/generic/tabulate.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/tag.h" "$(@D)/cuda/include/thrust/system/detail/generic/tag.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/generic/temporary_buffer.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/temporary_buffer.inl" "$(@D)/cuda/include/thrust/system/detail/generic/temporary_buffer.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/transform.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/transform.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform_reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/transform_reduce.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform_reduce.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform_scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/transform_scan.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform_scan.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/type_traits.h" "$(@D)/cuda/include/thrust/system/detail/generic/type_traits.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/uninitialized_copy.inl" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_copy.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/uninitialized_fill.inl" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_fill.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/unique.h" "$(@D)/cuda/include/thrust/system/detail/generic/unique.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/unique.inl" "$(@D)/cuda/include/thrust/system/detail/generic/unique.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/unique_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/generic/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/unique_by_key.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/internal/decompose.h" "$(@D)/cuda/include/thrust/system/detail/internal/decompose.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/sequential/adjacent_difference.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/assign_value.h" "$(@D)/cuda/include/thrust/system/detail/sequential/assign_value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/sequential/binary_search.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/copy.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/copy.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/copy_backward.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy_backward.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy_if.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/count.h" "$(@D)/cuda/include/thrust/system/detail/sequential/count.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/equal.h" "$(@D)/cuda/include/thrust/system/detail/sequential/equal.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/execution_policy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/execution_policy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/extrema.h" "$(@D)/cuda/include/thrust/system/detail/sequential/extrema.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/fill.h" "$(@D)/cuda/include/thrust/system/detail/sequential/fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/find.h" "$(@D)/cuda/include/thrust/system/detail/sequential/find.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/for_each.h" "$(@D)/cuda/include/thrust/system/detail/sequential/for_each.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/gather.h" "$(@D)/cuda/include/thrust/system/detail/sequential/gather.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/general_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/general_copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/generate.h" "$(@D)/cuda/include/thrust/system/detail/sequential/generate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/get_value.h" "$(@D)/cuda/include/thrust/system/detail/sequential/get_value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/sequential/inner_product.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/insertion_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/insertion_sort.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/iter_swap.h" "$(@D)/cuda/include/thrust/system/detail/sequential/iter_swap.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/logical.h" "$(@D)/cuda/include/thrust/system/detail/sequential/logical.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/detail/sequential/malloc_and_free.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/merge.h" "$(@D)/cuda/include/thrust/system/detail/sequential/merge.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/merge.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/merge.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/sequential/mismatch.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/partition.h" "$(@D)/cuda/include/thrust/system/detail/sequential/partition.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/reduce.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reduce_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/remove.h" "$(@D)/cuda/include/thrust/system/detail/sequential/remove.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/replace.h" "$(@D)/cuda/include/thrust/system/detail/sequential/replace.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/reverse.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reverse.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/scan.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scan_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/scatter.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scatter.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/sequence.h" "$(@D)/cuda/include/thrust/system/detail/sequential/sequence.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/sequential/set_operations.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/sort.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/sort.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/stable_merge_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_merge_sort.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/stable_merge_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_merge_sort.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/stable_primitive_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_primitive_sort.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/stable_primitive_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_primitive_sort.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/stable_radix_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_radix_sort.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/stable_radix_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_radix_sort.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/sequential/swap_ranges.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/sequential/tabulate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/sequential/temporary_buffer.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/transform.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform_reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform_scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/trivial_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/trivial_copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/uninitialized_copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/sequential/uninitialized_fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/unique.h" "$(@D)/cuda/include/thrust/system/detail/sequential/unique.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/sequential/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/unique_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/detail/system_error.inl" "$(@D)/cuda/include/thrust/system/detail/system_error.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/error_code.h" "$(@D)/cuda/include/thrust/system/error_code.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/omp/detail/adjacent_difference.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/omp/detail/assign_value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/omp/detail/binary_search.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/copy.h" "$(@D)/cuda/include/thrust/system/omp/detail/copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/copy.inl" "$(@D)/cuda/include/thrust/system/omp/detail/copy.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/omp/detail/copy_if.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/omp/detail/copy_if.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/count.h" "$(@D)/cuda/include/thrust/system/omp/detail/count.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/default_decomposition.h" "$(@D)/cuda/include/thrust/system/omp/detail/default_decomposition.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/default_decomposition.inl" "$(@D)/cuda/include/thrust/system/omp/detail/default_decomposition.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/equal.h" "$(@D)/cuda/include/thrust/system/omp/detail/equal.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/omp/detail/execution_policy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/extrema.h" "$(@D)/cuda/include/thrust/system/omp/detail/extrema.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/fill.h" "$(@D)/cuda/include/thrust/system/omp/detail/fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/find.h" "$(@D)/cuda/include/thrust/system/omp/detail/find.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/for_each.h" "$(@D)/cuda/include/thrust/system/omp/detail/for_each.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/for_each.inl" "$(@D)/cuda/include/thrust/system/omp/detail/for_each.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/gather.h" "$(@D)/cuda/include/thrust/system/omp/detail/gather.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/generate.h" "$(@D)/cuda/include/thrust/system/omp/detail/generate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/get_value.h" "$(@D)/cuda/include/thrust/system/omp/detail/get_value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/omp/detail/inner_product.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/omp/detail/iter_swap.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/logical.h" "$(@D)/cuda/include/thrust/system/omp/detail/logical.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/omp/detail/malloc_and_free.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/memory.inl" "$(@D)/cuda/include/thrust/system/omp/detail/memory.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/merge.h" "$(@D)/cuda/include/thrust/system/omp/detail/merge.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/omp/detail/mismatch.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/par.h" "$(@D)/cuda/include/thrust/system/omp/detail/par.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/partition.h" "$(@D)/cuda/include/thrust/system/omp/detail/partition.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/partition.inl" "$(@D)/cuda/include/thrust/system/omp/detail/partition.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/reduce.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_by_key.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_intervals.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/reduce_intervals.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_intervals.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/remove.h" "$(@D)/cuda/include/thrust/system/omp/detail/remove.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/remove.inl" "$(@D)/cuda/include/thrust/system/omp/detail/remove.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/replace.h" "$(@D)/cuda/include/thrust/system/omp/detail/replace.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/reverse.h" "$(@D)/cuda/include/thrust/system/omp/detail/reverse.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/scan.h" "$(@D)/cuda/include/thrust/system/omp/detail/scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/scan_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/scatter.h" "$(@D)/cuda/include/thrust/system/omp/detail/scatter.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/sequence.h" "$(@D)/cuda/include/thrust/system/omp/detail/sequence.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/omp/detail/set_operations.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/sort.h" "$(@D)/cuda/include/thrust/system/omp/detail/sort.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/sort.inl" "$(@D)/cuda/include/thrust/system/omp/detail/sort.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/omp/detail/swap_ranges.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/omp/detail/tabulate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/omp/detail/temporary_buffer.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/transform.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform_reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform_scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/omp/detail/uninitialized_copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/omp/detail/uninitialized_fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/unique.h" "$(@D)/cuda/include/thrust/system/omp/detail/unique.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/unique.inl" "$(@D)/cuda/include/thrust/system/omp/detail/unique.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/unique_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/omp/detail/unique_by_key.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/detail/vector.inl" "$(@D)/cuda/include/thrust/system/omp/detail/vector.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/execution_policy.h" "$(@D)/cuda/include/thrust/system/omp/execution_policy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/memory.h" "$(@D)/cuda/include/thrust/system/omp/memory.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/omp/vector.h" "$(@D)/cuda/include/thrust/system/omp/vector.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/system_error.h" "$(@D)/cuda/include/thrust/system/system_error.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/tbb/detail/adjacent_difference.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/tbb/detail/assign_value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/tbb/detail/binary_search.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/copy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/copy.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/copy.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/tbb/detail/copy_if.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/copy_if.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/count.h" "$(@D)/cuda/include/thrust/system/tbb/detail/count.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/equal.h" "$(@D)/cuda/include/thrust/system/tbb/detail/equal.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/execution_policy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/extrema.h" "$(@D)/cuda/include/thrust/system/tbb/detail/extrema.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/fill.h" "$(@D)/cuda/include/thrust/system/tbb/detail/fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/find.h" "$(@D)/cuda/include/thrust/system/tbb/detail/find.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/for_each.h" "$(@D)/cuda/include/thrust/system/tbb/detail/for_each.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/for_each.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/for_each.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/gather.h" "$(@D)/cuda/include/thrust/system/tbb/detail/gather.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/generate.h" "$(@D)/cuda/include/thrust/system/tbb/detail/generate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/get_value.h" "$(@D)/cuda/include/thrust/system/tbb/detail/get_value.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/tbb/detail/inner_product.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/tbb/detail/iter_swap.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/logical.h" "$(@D)/cuda/include/thrust/system/tbb/detail/logical.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/tbb/detail/malloc_and_free.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/memory.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/memory.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/merge.h" "$(@D)/cuda/include/thrust/system/tbb/detail/merge.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/merge.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/merge.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/tbb/detail/mismatch.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/par.h" "$(@D)/cuda/include/thrust/system/tbb/detail/par.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/partition.h" "$(@D)/cuda/include/thrust/system/tbb/detail/partition.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/partition.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/partition.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/reduce.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_by_key.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_intervals.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/remove.h" "$(@D)/cuda/include/thrust/system/tbb/detail/remove.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/remove.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/remove.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/replace.h" "$(@D)/cuda/include/thrust/system/tbb/detail/replace.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/reverse.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reverse.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/scan.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/scan.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/scan.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scan_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/scatter.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scatter.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/sequence.h" "$(@D)/cuda/include/thrust/system/tbb/detail/sequence.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/tbb/detail/set_operations.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/sort.h" "$(@D)/cuda/include/thrust/system/tbb/detail/sort.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/sort.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/sort.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/tbb/detail/swap_ranges.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/tbb/detail/tabulate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/tbb/detail/temporary_buffer.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/transform.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform_reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform_scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/uninitialized_copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/tbb/detail/uninitialized_fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/unique.h" "$(@D)/cuda/include/thrust/system/tbb/detail/unique.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/unique.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/unique.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/unique_by_key.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/unique_by_key.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/detail/vector.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/vector.inl" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/execution_policy.h" "$(@D)/cuda/include/thrust/system/tbb/execution_policy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/memory.h" "$(@D)/cuda/include/thrust/system/tbb/memory.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system/tbb/vector.h" "$(@D)/cuda/include/thrust/system/tbb/vector.h" && cp -f "/usr/local/cuda-10.0/include/thrust/system_error.h" "$(@D)/cuda/include/thrust/system_error.h" && cp -f "/usr/local/cuda-10.0/include/thrust/tabulate.h" "$(@D)/cuda/include/thrust/tabulate.h" && cp -f "/usr/local/cuda-10.0/include/thrust/transform.h" "$(@D)/cuda/include/thrust/transform.h" && cp -f "/usr/local/cuda-10.0/include/thrust/transform_reduce.h" "$(@D)/cuda/include/thrust/transform_reduce.h" && cp -f "/usr/local/cuda-10.0/include/thrust/transform_scan.h" "$(@D)/cuda/include/thrust/transform_scan.h" && cp -f "/usr/local/cuda-10.0/include/thrust/tuple.h" "$(@D)/cuda/include/thrust/tuple.h" && cp -f "/usr/local/cuda-10.0/include/thrust/uninitialized_copy.h" "$(@D)/cuda/include/thrust/uninitialized_copy.h" && cp -f "/usr/local/cuda-10.0/include/thrust/uninitialized_fill.h" "$(@D)/cuda/include/thrust/uninitialized_fill.h" && cp -f "/usr/local/cuda-10.0/include/thrust/unique.h" "$(@D)/cuda/include/thrust/unique.h" && cp -f "/usr/local/cuda-10.0/include/thrust/version.h" "$(@D)/cuda/include/thrust/version.h" && cp -f "/usr/local/cuda-10.0/include/vector_functions.h" "$(@D)/cuda/include/vector_functions.h" && cp -f "/usr/local/cuda-10.0/include/vector_functions.hpp" "$(@D)/cuda/include/vector_functions.hpp" && cp -f "/usr/local/cuda-10.0/include/vector_types.h" "$(@D)/cuda/include/vector_types.h"
+ """,
+)
+
+genrule(
+ name = "cuda-nvvm",
+ outs = [
+ "cuda/nvvm/libdevice/libdevice.10.bc",
+ ],
+ cmd = """
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp -f "/usr/local/cuda-10.0/nvvm/libdevice/libdevice.10.bc" "$(@D)//libdevice.10.bc"
+ """,
+)
+
+genrule(
+ name = "cuda-extras",
+ outs = [
+ "cuda/extras/CUPTI/include/GL/gl.h",
+ "cuda/extras/CUPTI/include/GL/glew.h",
+ "cuda/extras/CUPTI/include/GL/glext.h",
+ "cuda/extras/CUPTI/include/GL/glu.h",
+ "cuda/extras/CUPTI/include/GL/glut.h",
+ "cuda/extras/CUPTI/include/GL/glx.h",
+ "cuda/extras/CUPTI/include/GL/glxext.h",
+ "cuda/extras/CUPTI/include/GL/wglew.h",
+ "cuda/extras/CUPTI/include/GL/wglext.h",
+ "cuda/extras/CUPTI/include/cuda_stdint.h",
+ "cuda/extras/CUPTI/include/cupti.h",
+ "cuda/extras/CUPTI/include/cupti_activity.h",
+ "cuda/extras/CUPTI/include/cupti_callbacks.h",
+ "cuda/extras/CUPTI/include/cupti_driver_cbid.h",
+ "cuda/extras/CUPTI/include/cupti_events.h",
+ "cuda/extras/CUPTI/include/cupti_metrics.h",
+ "cuda/extras/CUPTI/include/cupti_nvtx_cbid.h",
+ "cuda/extras/CUPTI/include/cupti_result.h",
+ "cuda/extras/CUPTI/include/cupti_runtime_cbid.h",
+ "cuda/extras/CUPTI/include/cupti_version.h",
+ "cuda/extras/CUPTI/include/generated_cudaGL_meta.h",
+ "cuda/extras/CUPTI/include/generated_cudaVDPAU_meta.h",
+ "cuda/extras/CUPTI/include/generated_cuda_gl_interop_meta.h",
+ "cuda/extras/CUPTI/include/generated_cuda_meta.h",
+ "cuda/extras/CUPTI/include/generated_cuda_runtime_api_meta.h",
+ "cuda/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h",
+ "cuda/extras/CUPTI/include/generated_nvtx_meta.h",
+ "cuda/extras/CUPTI/include/openacc/cupti_openacc.h",
+ "cuda/extras/CUPTI/include/openmp/cupti_openmp.h",
+ "cuda/extras/CUPTI/include/openmp/ompt.h",
+ ],
+ cmd = """
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/GL/gl.h" "$(@D)/cuda/extras/CUPTI/include/GL/gl.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/GL/glew.h" "$(@D)/cuda/extras/CUPTI/include/GL/glew.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/GL/glext.h" "$(@D)/cuda/extras/CUPTI/include/GL/glext.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/GL/glu.h" "$(@D)/cuda/extras/CUPTI/include/GL/glu.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/GL/glut.h" "$(@D)/cuda/extras/CUPTI/include/GL/glut.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/GL/glx.h" "$(@D)/cuda/extras/CUPTI/include/GL/glx.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/GL/glxext.h" "$(@D)/cuda/extras/CUPTI/include/GL/glxext.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/GL/wglew.h" "$(@D)/cuda/extras/CUPTI/include/GL/wglew.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/GL/wglext.h" "$(@D)/cuda/extras/CUPTI/include/GL/wglext.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/cuda_stdint.h" "$(@D)/cuda/extras/CUPTI/include/cuda_stdint.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/cupti.h" "$(@D)/cuda/extras/CUPTI/include/cupti.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/cupti_activity.h" "$(@D)/cuda/extras/CUPTI/include/cupti_activity.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/cupti_callbacks.h" "$(@D)/cuda/extras/CUPTI/include/cupti_callbacks.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/cupti_driver_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_driver_cbid.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/cupti_events.h" "$(@D)/cuda/extras/CUPTI/include/cupti_events.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/cupti_metrics.h" "$(@D)/cuda/extras/CUPTI/include/cupti_metrics.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/cupti_nvtx_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_nvtx_cbid.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/cupti_result.h" "$(@D)/cuda/extras/CUPTI/include/cupti_result.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/cupti_runtime_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_runtime_cbid.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/cupti_version.h" "$(@D)/cuda/extras/CUPTI/include/cupti_version.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/generated_cudaGL_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cudaGL_meta.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/generated_cudaVDPAU_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cudaVDPAU_meta.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/generated_cuda_gl_interop_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_gl_interop_meta.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/generated_cuda_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_meta.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/generated_cuda_runtime_api_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_runtime_api_meta.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/generated_nvtx_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_nvtx_meta.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/openacc/cupti_openacc.h" "$(@D)/cuda/extras/CUPTI/include/openacc/cupti_openacc.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/openmp/cupti_openmp.h" "$(@D)/cuda/extras/CUPTI/include/openmp/cupti_openmp.h" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/include/openmp/ompt.h" "$(@D)/cuda/extras/CUPTI/include/openmp/ompt.h"
+ """,
+)
+
+genrule(
+ name = "cuda-lib",
+ outs = [
+ "cuda/lib/libcuda.so",
+ "cuda/lib/libcudart.so.10.0",
+ "cuda/lib/libcudart_static.a",
+ "cuda/lib/libcublas.so.10.0",
+ "cuda/lib/libcusolver.so.10.0",
+ "cuda/lib/libcurand.so.10.0",
+ "cuda/lib/libcufft.so.10.0",
+ "cuda/lib/libcudnn.so.7",
+ "cuda/lib/libcupti.so.10.0",
+ ],
+ cmd = """
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp -f "/usr/local/cuda-10.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp -f "/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudart.so.10.0.130" "$(@D)/cuda/lib/libcudart.so.10.0" && cp -f "/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp -f "/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcublas.so.10.0.130" "$(@D)/cuda/lib/libcublas.so.10.0" && cp -f "/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcusolver.so.10.0.130" "$(@D)/cuda/lib/libcusolver.so.10.0" && cp -f "/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcurand.so.10.0.130" "$(@D)/cuda/lib/libcurand.so.10.0" && cp -f "/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcufft.so.10.0.145" "$(@D)/cuda/lib/libcufft.so.10.0" && cp -f "/usr/lib/x86_64-linux-gnu/libcudnn.so.7.3.1" "$(@D)/cuda/lib/libcudnn.so.7" && cp -f "/usr/local/cuda-10.0/extras/CUPTI/lib64/libcupti.so.10.0.130" "$(@D)/cuda/lib/libcupti.so.10.0"
+ """,
+)
+
+genrule(
+ name = "cudnn-include",
+ outs = [
+ "cuda/include/cudnn.h",
+ ],
+ cmd = """
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp -f "/usr/include/cudnn.h" "$(@D)/cudnn.h"
+ """,
+)
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl b/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
new file mode 100755
index 0000000..a53c891
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
@@ -0,0 +1,31 @@
+# Macros for building CUDA code.
+def if_cuda(if_true, if_false = []):
+ """Shorthand for select()'ing on whether we're building with CUDA.
+
+ Returns a select statement which evaluates to if_true if we're building
+ with CUDA enabled. Otherwise, the select statement evaluates to if_false.
+
+ """
+ return select({
+ "@local_config_cuda//cuda:using_nvcc": if_true,
+ "@local_config_cuda//cuda:using_clang": if_true,
+ "//conditions:default": if_false,
+ })
+
+def cuda_default_copts():
+ """Default options for all CUDA compilations."""
+ return if_cuda(["-x", "cuda", "-DGOOGLE_CUDA=1"] + [])
+
+def cuda_is_configured():
+ """Returns true if CUDA was enabled during the configure process."""
+ return True
+
+def if_cuda_is_configured(x):
+ """Tests if the CUDA was enabled during the configure process.
+
+ Unlike if_cuda(), this does not require that we are building with
+ --config=cuda. Used to allow non-CUDA code to depend on CUDA libraries.
+ """
+ if cuda_is_configured():
+ return x
+ return []
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/cuda/cuda_config.h b/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/cuda/cuda_config.h
new file mode 100755
index 0000000..0934618
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/cuda/cuda_config.h
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef CUDA_CUDA_CONFIG_H_
+#define CUDA_CUDA_CONFIG_H_
+
+#define TF_CUDA_CAPABILITIES CudaVersion("3.0")
+
+#define TF_CUDA_VERSION "10.0"
+#define TF_CUDNN_VERSION "7"
+
+#define TF_CUDA_TOOLKIT_PATH "/usr/local/cuda-10.0"
+
+#endif // CUDA_CUDA_CONFIG_H_
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
new file mode 100755
index 0000000..6442e76
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
@@ -0,0 +1,87 @@
+licenses(["restricted"])
+
+package(default_visibility = ["//visibility:public"])
+
+toolchain(
+ name = "toolchain-linux-x86_64",
+ exec_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ target_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ toolchain = ":cc-compiler-local",
+ toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
+)
+
+cc_toolchain_suite(
+ name = "toolchain",
+ toolchains = {
+ "local|compiler": ":cc-compiler-local",
+ "darwin|compiler": ":cc-compiler-darwin",
+ "x64_windows|msvc-cl": ":cc-compiler-windows",
+ },
+)
+
+cc_toolchain(
+ name = "cc-compiler-local",
+ all_files = ":crosstool_wrapper_driver_is_not_gcc",
+ compiler_files = ":empty",
+ cpu = "local",
+ dwp_files = ":empty",
+ dynamic_runtime_libs = [":empty"],
+ linker_files = ":crosstool_wrapper_driver_is_not_gcc",
+ objcopy_files = ":empty",
+ static_runtime_libs = [":empty"],
+ strip_files = ":empty",
+ # To support linker flags that need to go to the start of command line
+ # we need the toolchain to support parameter files. Parameter files are
+ # last on the command line and contain all shared libraries to link, so all
+ # regular options will be left of them.
+ supports_param_files = 1,
+)
+
+cc_toolchain(
+ name = "cc-compiler-darwin",
+ all_files = ":crosstool_wrapper_driver_is_not_gcc",
+ compiler_files = ":empty",
+ cpu = "darwin",
+ dwp_files = ":empty",
+ dynamic_runtime_libs = [":empty"],
+ linker_files = ":crosstool_wrapper_driver_is_not_gcc",
+ objcopy_files = ":empty",
+ static_runtime_libs = [":empty"],
+ strip_files = ":empty",
+ supports_param_files = 0,
+)
+
+cc_toolchain(
+ name = "cc-compiler-windows",
+ all_files = ":windows_msvc_wrapper_files",
+ compiler_files = ":empty",
+ cpu = "x64_windows",
+ dwp_files = ":empty",
+ dynamic_runtime_libs = [":empty"],
+ linker_files = ":windows_msvc_wrapper_files",
+ objcopy_files = ":empty",
+ static_runtime_libs = [":empty"],
+ strip_files = ":empty",
+ supports_param_files = 1,
+)
+
+filegroup(
+ name = "empty",
+ srcs = [],
+)
+
+filegroup(
+ name = "crosstool_wrapper_driver_is_not_gcc",
+ srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"],
+)
+
+filegroup(
+ name = "windows_msvc_wrapper_files",
+ srcs = glob(["windows/msvc_*"]),
+)
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/CROSSTOOL b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/CROSSTOOL
new file mode 100755
index 0000000..1c2e8bc
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/CROSSTOOL
@@ -0,0 +1,1431 @@
+major_version: "local"
+minor_version: ""
+default_target_cpu: "same_as_host"
+
+default_toolchain {
+ cpu: "k8"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "piii"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "arm"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "darwin"
+ toolchain_identifier: "local_darwin"
+}
+default_toolchain {
+ cpu: "ppc"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "x64_windows"
+ toolchain_identifier: "local_windows"
+}
+
+toolchain {
+ abi_version: "local"
+ abi_libc_version: "local"
+ compiler: "compiler"
+ host_system_name: "local"
+ needsPic: true
+ target_libc: "local"
+ target_cpu: "local"
+ target_system_name: "local"
+ toolchain_identifier: "local_linux"
+
+ feature {
+ name: "c++11"
+ flag_set {
+ action: "c++-compile"
+ flag_group {
+ flag: "-std=c++11"
+ }
+ }
+ }
+
+ feature {
+ name: "stdlib"
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-lstdc++"
+ }
+ }
+ }
+
+ feature {
+ name: "determinism"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Make C++ compilation deterministic. Use linkstamping instead of these
+ # compiler symbols.
+ flag: "-Wno-builtin-macro-redefined"
+ flag: "-D__DATE__=\"redacted\""
+ flag: "-D__TIMESTAMP__=\"redacted\""
+ flag: "-D__TIME__=\"redacted\""
+ }
+ }
+ }
+
+ feature {
+ name: "alwayslink"
+ flag_set {
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-Wl,-no-as-needed"
+ }
+ }
+ }
+
+ # This feature will be enabled for builds that support pic by bazel.
+ feature {
+ name: "pic"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ expand_if_all_available: "pic"
+ flag: "-fPIC"
+ }
+ flag_group {
+ expand_if_none_available: "pic"
+ flag: "-fPIE"
+ }
+ }
+ }
+
+ # Security hardening on by default.
+ feature {
+ name: "hardening"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases.
+ # We need to undef it before redefining it as some distributions now
+ # have it enabled by default.
+ flag: "-U_FORTIFY_SOURCE"
+ flag: "-D_FORTIFY_SOURCE=1"
+ flag: "-fstack-protector"
+ }
+ }
+ flag_set {
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-Wl,-z,relro,-z,now"
+ }
+ }
+ flag_set {
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-pie"
+ flag: "-Wl,-z,relro,-z,now"
+ }
+ }
+ }
+
+ feature {
+ name: "warnings"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # All warnings are enabled. Maybe enable -Werror as well?
+ flag: "-Wall"
+
+ }
+ }
+ }
+
+ # Keep stack frames for debugging, even in opt mode.
+ feature {
+ name: "frame-pointer"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-fno-omit-frame-pointer"
+ }
+ }
+ }
+
+ feature {
+ name: "build-id"
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ # Stamp the binary with a unique identifier.
+ flag: "-Wl,--build-id=md5"
+ flag: "-Wl,--hash-style=gnu"
+ }
+ }
+ }
+
+ feature {
+ name: "no-canonical-prefixes"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-no-canonical-prefixes"
+ flag: "-fno-canonical-system-headers"
+ }
+ }
+ }
+
+ feature {
+ name: "disable-assertions"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-DNDEBUG"
+ }
+ }
+ }
+
+ feature {
+ name: "linker-bin-path"
+
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-B/usr/bin"
+ }
+ }
+ }
+
+ feature {
+ name: "common"
+ implies: "stdlib"
+ implies: "c++11"
+ implies: "determinism"
+ implies: "alwayslink"
+ implies: "hardening"
+ implies: "warnings"
+ implies: "frame-pointer"
+ implies: "build-id"
+ implies: "no-canonical-prefixes"
+ implies: "linker-bin-path"
+ }
+
+ feature {
+ name: "opt"
+ implies: "common"
+ implies: "disable-assertions"
+
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # No debug symbols.
+ # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt
+ # or even generally? However, that can't happen here, as it requires
+ # special handling in Bazel.
+ flag: "-g0"
+
+ # Conservative choice for -O
+ # -O3 can increase binary size and even slow down the resulting binaries.
+ # Profile first and / or use FDO if you need better performance than this.
+ flag: "-O2"
+
+ # Removal of unused code and data at link time (can this increase binary size in some cases?).
+ flag: "-ffunction-sections"
+ flag: "-fdata-sections"
+ }
+ }
+ flag_set {
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-Wl,--gc-sections"
+ }
+ }
+ }
+
+ feature {
+ name: "fastbuild"
+ implies: "common"
+ }
+
+ feature {
+ name: "dbg"
+ implies: "common"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-g"
+ }
+ }
+ }
+
+ # Set clang as a C/C++ compiler.
+ tool_path { name: "gcc" path: "clang/bin/crosstool_wrapper_driver_is_not_gcc" }
+
+ # Use the default system toolchain for everything else.
+ tool_path { name: "ar" path: "/usr/bin/ar" }
+ tool_path { name: "compat-ld" path: "/usr/bin/ld" }
+ tool_path { name: "cpp" path: "/usr/bin/cpp" }
+ tool_path { name: "dwp" path: "/usr/bin/dwp" }
+ tool_path { name: "gcov" path: "/usr/bin/gcov" }
+ tool_path { name: "ld" path: "/usr/bin/ld" }
+ tool_path { name: "nm" path: "/usr/bin/nm" }
+ tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
+ tool_path { name: "objdump" path: "/usr/bin/objdump" }
+ tool_path { name: "strip" path: "/usr/bin/strip" }
+
+ # Enabled dynamic linking.
+ linking_mode_flags { mode: DYNAMIC }
+
+ cxx_builtin_include_directory: "/usr/include/c++/4.8"
+ cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu/c++/4.8"
+ cxx_builtin_include_directory: "/usr/include/c++/4.8/backward"
+ cxx_builtin_include_directory: "/usr/lib/gcc/x86_64-linux-gnu/4.8/include"
+ cxx_builtin_include_directory: "/usr/local/include"
+ cxx_builtin_include_directory: "/usr/lib/gcc/x86_64-linux-gnu/4.8/include-fixed"
+ cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu"
+ cxx_builtin_include_directory: "/usr/include"
+ cxx_builtin_include_directory: "/usr/local/cuda-10.0/targets/x86_64-linux/include"
+ cxx_builtin_include_directory: "/usr/local/cuda-10.0/include"
+ cxx_builtin_include_directory: "/usr/local/cuda-10.0/extras/CUPTI/include"
+ cxx_builtin_include_directory: "/usr/include"
+}
+
+toolchain {
+ abi_version: "local"
+ abi_libc_version: "local"
+ compiler: "compiler"
+ host_system_name: "local"
+ needsPic: true
+ target_libc: "macosx"
+ target_cpu: "darwin"
+ target_system_name: "local"
+ toolchain_identifier: "local_darwin"
+ feature {
+ name: "c++11"
+ flag_set {
+ action: "c++-compile"
+ flag_group {
+ flag: "-std=c++11"
+ }
+ }
+ }
+
+ feature {
+ name: "stdlib"
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-lc++"
+ }
+ }
+ }
+
+ feature {
+ name: "determinism"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Make C++ compilation deterministic. Use linkstamping instead of these
+ # compiler symbols.
+ flag: "-Wno-builtin-macro-redefined"
+ flag: "-D__DATE__=\"redacted\""
+ flag: "-D__TIMESTAMP__=\"redacted\""
+ flag: "-D__TIME__=\"redacted\""
+ }
+ }
+ }
+
+ # This feature will be enabled for builds that support pic by bazel.
+ feature {
+ name: "pic"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ expand_if_all_available: "pic"
+ flag: "-fPIC"
+ }
+ flag_group {
+ expand_if_none_available: "pic"
+ flag: "-fPIE"
+ }
+ }
+ }
+
+ # Security hardening on by default.
+ feature {
+ name: "hardening"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases.
+ # We need to undef it before redefining it as some distributions now
+ # have it enabled by default.
+ flag: "-U_FORTIFY_SOURCE"
+ flag: "-D_FORTIFY_SOURCE=1"
+ flag: "-fstack-protector"
+ }
+ }
+ flag_set {
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-pie"
+ }
+ }
+ }
+
+ feature {
+ name: "warnings"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # All warnings are enabled. Maybe enable -Werror as well?
+ flag: "-Wall"
+
+ }
+ }
+ }
+
+ # Keep stack frames for debugging, even in opt mode.
+ feature {
+ name: "frame-pointer"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-fno-omit-frame-pointer"
+ }
+ }
+ }
+
+ feature {
+ name: "no-canonical-prefixes"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag:"-no-canonical-prefixes"
+ }
+ }
+ }
+
+ feature {
+ name: "disable-assertions"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-DNDEBUG"
+ }
+ }
+ }
+
+ feature {
+ name: "linker-bin-path"
+
+ flag_set {
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "-B/usr/bin"
+ }
+ }
+ }
+
+ feature {
+ name: "undefined-dynamic"
+ flag_set {
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-executable"
+ flag_group {
+ flag: "-undefined"
+ flag: "dynamic_lookup"
+ }
+ }
+ }
+
+ feature {
+ name: "common"
+ implies: "stdlib"
+ implies: "c++11"
+ implies: "determinism"
+ implies: "hardening"
+ implies: "warnings"
+ implies: "frame-pointer"
+ implies: "no-canonical-prefixes"
+ implies: "linker-bin-path"
+ implies: "undefined-dynamic"
+ }
+
+ feature {
+ name: "opt"
+ implies: "common"
+ implies: "disable-assertions"
+
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ # No debug symbols.
+ # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt
+ # or even generally? However, that can't happen here, as it requires
+ # special handling in Bazel.
+ flag: "-g0"
+
+ # Conservative choice for -O
+ # -O3 can increase binary size and even slow down the resulting binaries.
+ # Profile first and / or use FDO if you need better performance than this.
+ flag: "-O2"
+
+ # Removal of unused code and data at link time (can this increase binary size in some cases?).
+ flag: "-ffunction-sections"
+ flag: "-fdata-sections"
+ }
+ }
+ }
+
+ feature {
+ name: "fastbuild"
+ implies: "common"
+ }
+
+ feature {
+ name: "dbg"
+ implies: "common"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ flag_group {
+ flag: "-g"
+ }
+ }
+ }
+
+ # Set clang as a C/C++ compiler.
+ tool_path { name: "gcc" path: "clang/bin/crosstool_wrapper_driver_is_not_gcc" }
+
+ # Use the default system toolchain for everything else.
+ tool_path { name: "ar" path: "/usr/bin/libtool" }
+ tool_path { name: "compat-ld" path: "/usr/bin/ld" }
+ tool_path { name: "cpp" path: "/usr/bin/cpp" }
+ tool_path { name: "dwp" path: "/usr/bin/dwp" }
+ tool_path { name: "gcov" path: "/usr/bin/gcov" }
+ tool_path { name: "ld" path: "/usr/bin/ld" }
+ tool_path { name: "nm" path: "/usr/bin/nm" }
+ tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
+ tool_path { name: "objdump" path: "/usr/bin/objdump" }
+ tool_path { name: "strip" path: "/usr/bin/strip" }
+
+ # Enabled dynamic linking.
+ linking_mode_flags { mode: DYNAMIC }
+
+ cxx_builtin_include_directory: "/usr/include/c++/4.8"
+ cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu/c++/4.8"
+ cxx_builtin_include_directory: "/usr/include/c++/4.8/backward"
+ cxx_builtin_include_directory: "/usr/lib/gcc/x86_64-linux-gnu/4.8/include"
+ cxx_builtin_include_directory: "/usr/local/include"
+ cxx_builtin_include_directory: "/usr/lib/gcc/x86_64-linux-gnu/4.8/include-fixed"
+ cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu"
+ cxx_builtin_include_directory: "/usr/include"
+ cxx_builtin_include_directory: "/usr/local/cuda-10.0/targets/x86_64-linux/include"
+ cxx_builtin_include_directory: "/usr/local/cuda-10.0/include"
+ cxx_builtin_include_directory: "/usr/local/cuda-10.0/extras/CUPTI/include"
+ cxx_builtin_include_directory: "/usr/include"
+}
+
+toolchain {
+ toolchain_identifier: "local_windows"
+ host_system_name: "local"
+ target_system_name: "local"
+
+ abi_version: "local"
+ abi_libc_version: "local"
+ target_cpu: "x64_windows"
+ compiler: "msvc-cl"
+ target_libc: "msvcrt"
+
+
+
+ tool_path {
+ name: "ar"
+ path: ""
+ }
+ tool_path {
+ name: "ml"
+ path: ""
+ }
+ tool_path {
+ name: "cpp"
+ path: ""
+ }
+ tool_path {
+ name: "gcc"
+ path: ""
+ }
+ tool_path {
+ name: "gcov"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "ld"
+ path: ""
+ }
+ tool_path {
+ name: "nm"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "objcopy"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "objdump"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ tool_path {
+ name: "strip"
+ path: "wrapper/bin/msvc_nop.bat"
+ }
+ supports_interface_shared_objects: true
+
+ # TODO(pcloudy): Review those flags below, they should be defined by cl.exe
+ compiler_flag: "/DCOMPILER_MSVC"
+
+ # Don't define min/max macros in windows.h.
+ compiler_flag: "/DNOMINMAX"
+
+ # Platform defines.
+ compiler_flag: "/D_WIN32_WINNT=0x0600"
+ # Turn off warning messages.
+ compiler_flag: "/D_CRT_SECURE_NO_DEPRECATE"
+ compiler_flag: "/D_CRT_SECURE_NO_WARNINGS"
+ compiler_flag: "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS"
+
+ # Useful options to have on for compilation.
+ # Increase the capacity of object files to 2^32 sections.
+ compiler_flag: "/bigobj"
+ # Allocate 500MB for precomputed headers.
+ compiler_flag: "/Zm500"
+ # Use unsigned char by default.
+ compiler_flag: "/J"
+ # Use function level linking.
+ compiler_flag: "/Gy"
+ # Use string pooling.
+ compiler_flag: "/GF"
+ # Catch C++ exceptions only and tell the compiler to assume that functions declared
+ # as extern "C" never throw a C++ exception.
+ compiler_flag: "/EHsc"
+
+ # Globally disabled warnings.
+ # Don't warn about elements of array being be default initialized.
+ compiler_flag: "/wd4351"
+ # Don't warn about no matching delete found.
+ compiler_flag: "/wd4291"
+ # Don't warn about diamond inheritance patterns.
+ compiler_flag: "/wd4250"
+ # Don't warn about insecure functions (e.g. non _s functions).
+ compiler_flag: "/wd4996"
+
+ linker_flag: "/MACHINE:X64"
+
+ feature {
+ name: "no_legacy_features"
+ }
+
+ # Suppress startup banner.
+ feature {
+ name: "nologo"
+ flag_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-module-compile"
+ action: "c++-module-codegen"
+ action: "c++-header-parsing"
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-static-library"
+ flag_group {
+ flag: "/nologo"
+ }
+ }
+ }
+
+ feature {
+ name: 'has_configured_linker_path'
+ }
+
+ # This feature indicates strip is not supported, building stripped binary will just result a copy of orignial binary
+ feature {
+ name: 'no_stripping'
+ }
+
+ # This feature indicates this is a toolchain targeting Windows.
+ feature {
+ name: 'targets_windows'
+ implies: 'copy_dynamic_libraries_to_binary'
+ enabled: true
+ }
+
+ feature {
+ name: 'copy_dynamic_libraries_to_binary'
+ }
+
+ action_config {
+ config_name: 'assemble'
+ action_name: 'assemble'
+ tool {
+ tool_path: ''
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'sysroot'
+ }
+
+ action_config {
+ config_name: 'preprocess-assemble'
+ action_name: 'preprocess-assemble'
+ tool {
+ tool_path: ''
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'sysroot'
+ }
+
+ action_config {
+ config_name: 'c-compile'
+ action_name: 'c-compile'
+ tool {
+ tool_path: ''
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'legacy_compile_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'parse_showincludes'
+ implies: 'user_compile_flags'
+ implies: 'sysroot'
+ implies: 'unfiltered_compile_flags'
+ }
+
+ action_config {
+ config_name: 'c++-compile'
+ action_name: 'c++-compile'
+ tool {
+ tool_path: ''
+ }
+ implies: 'compiler_input_flags'
+ implies: 'compiler_output_flags'
+ implies: 'legacy_compile_flags'
+ implies: 'nologo'
+ implies: 'msvc_env'
+ implies: 'parse_showincludes'
+ implies: 'user_compile_flags'
+ implies: 'sysroot'
+ implies: 'unfiltered_compile_flags'
+ }
+
+ action_config {
+ config_name: 'c++-link-executable'
+ action_name: 'c++-link-executable'
+ tool {
+ tool_path: ''
+ }
+ implies: 'nologo'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ }
+
+ action_config {
+ config_name: 'c++-link-dynamic-library'
+ action_name: 'c++-link-dynamic-library'
+ tool {
+ tool_path: ''
+ }
+ implies: 'nologo'
+ implies: 'shared_flag'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ implies: 'has_configured_linker_path'
+ implies: 'def_file'
+ }
+
+ action_config {
+ config_name: 'c++-link-nodeps-dynamic-library'
+ action_name: 'c++-link-nodeps-dynamic-library'
+ tool {
+ tool_path: ''
+ }
+ implies: 'nologo'
+ implies: 'shared_flag'
+ implies: 'linkstamps'
+ implies: 'output_execpath_flags'
+ implies: 'input_param_flags'
+ implies: 'user_link_flags'
+ implies: 'legacy_link_flags'
+ implies: 'linker_subsystem_flag'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ implies: 'no_stripping'
+ implies: 'has_configured_linker_path'
+ implies: 'def_file'
+ }
+
+ action_config {
+ config_name: 'c++-link-static-library'
+ action_name: 'c++-link-static-library'
+ tool {
+ tool_path: ''
+ }
+ implies: 'nologo'
+ implies: 'archiver_flags'
+ implies: 'input_param_flags'
+ implies: 'linker_param_file'
+ implies: 'msvc_env'
+ }
+
+ # TODO(b/65151735): Remove legacy_compile_flags feature when legacy fields are
+ # not used in this crosstool
+ feature {
+ name: 'legacy_compile_flags'
+ flag_set {
+ expand_if_all_available: 'legacy_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'legacy_compile_flags'
+ flag: '%{legacy_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: "msvc_env"
+ env_set {
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-module-compile"
+ action: "c++-module-codegen"
+ action: "c++-header-parsing"
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c++-link-executable"
+ action: "c++-link-dynamic-library"
+ action: "c++-link-nodeps-dynamic-library"
+ action: "c++-link-static-library"
+ env_entry {
+ key: "PATH"
+ value: ""
+ }
+ env_entry {
+ key: "INCLUDE"
+ value: ""
+ }
+ env_entry {
+ key: "LIB"
+ value: ""
+ }
+ env_entry {
+ key: "TMP"
+ value: ""
+ }
+ env_entry {
+ key: "TEMP"
+ value: ""
+ }
+ }
+ }
+
+ feature {
+ name: 'include_paths'
+ flag_set {
+ action: "assemble"
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ flag_group {
+ iterate_over: 'quote_include_paths'
+ flag: '/I%{quote_include_paths}'
+ }
+ flag_group {
+ iterate_over: 'include_paths'
+ flag: '/I%{include_paths}'
+ }
+ flag_group {
+ iterate_over: 'system_include_paths'
+ flag: '/I%{system_include_paths}'
+ }
+ }
+ }
+
+ feature {
+ name: "preprocessor_defines"
+ flag_set {
+ action: "assemble"
+ action: "preprocess-assemble"
+ action: "c-compile"
+ action: "c++-compile"
+ action: "c++-header-parsing"
+ action: "c++-module-compile"
+ flag_group {
+ flag: "/D%{preprocessor_defines}"
+ iterate_over: "preprocessor_defines"
+ }
+ }
+ }
+
+ # Tell Bazel to parse the output of /showIncludes
+ feature {
+ name: 'parse_showincludes'
+ flag_set {
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-module-compile'
+ action: 'c++-header-parsing'
+ flag_group {
+ flag: "/showIncludes"
+ }
+ }
+ }
+
+
+ feature {
+ name: 'generate_pdb_file'
+ requires: {
+ feature: 'dbg'
+ }
+ requires: {
+ feature: 'fastbuild'
+ }
+ }
+
+ feature {
+ name: 'shared_flag'
+ flag_set {
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/DLL'
+ }
+ }
+ }
+
+ feature {
+ name: 'linkstamps'
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ expand_if_all_available: 'linkstamp_paths'
+ flag_group {
+ iterate_over: 'linkstamp_paths'
+ flag: '%{linkstamp_paths}'
+ }
+ }
+ }
+
+ feature {
+ name: 'output_execpath_flags'
+ flag_set {
+ expand_if_all_available: 'output_execpath'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/OUT:%{output_execpath}'
+ }
+ }
+ }
+
+ feature {
+ name: 'archiver_flags'
+ flag_set {
+ expand_if_all_available: 'output_execpath'
+ action: 'c++-link-static-library'
+ flag_group {
+ flag: '/OUT:%{output_execpath}'
+ }
+ }
+ }
+
+ feature {
+ name: 'input_param_flags'
+ flag_set {
+ expand_if_all_available: 'interface_library_output_path'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/IMPLIB:%{interface_library_output_path}"
+ }
+ }
+ flag_set {
+ expand_if_all_available: 'libopts'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'libopts'
+ flag: '%{libopts}'
+ }
+ }
+ flag_set {
+ expand_if_all_available: 'libraries_to_link'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ action: 'c++-link-static-library'
+ flag_group {
+ iterate_over: 'libraries_to_link'
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'object_file_group'
+ }
+ iterate_over: 'libraries_to_link.object_files'
+ flag_group {
+ flag: '%{libraries_to_link.object_files}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'object_file'
+ }
+ flag_group {
+ flag: '%{libraries_to_link.name}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'interface_library'
+ }
+ flag_group {
+ flag: '%{libraries_to_link.name}'
+ }
+ }
+ flag_group {
+ expand_if_equal: {
+ variable: 'libraries_to_link.type'
+ value: 'static_library'
+ }
+ flag_group {
+ expand_if_false: 'libraries_to_link.is_whole_archive'
+ flag: '%{libraries_to_link.name}'
+ }
+ flag_group {
+ expand_if_true: 'libraries_to_link.is_whole_archive'
+ flag: '/WHOLEARCHIVE:%{libraries_to_link.name}'
+ }
+ }
+ }
+ }
+ }
+
+ # Since this feature is declared earlier in the CROSSTOOL than
+ # "user_link_flags", this feature will be applied prior to it anwyhere they
+ # are both implied. And since "user_link_flags" contains the linkopts from
+ # the build rule, this allows the user to override the /SUBSYSTEM in the BUILD
+ # file.
+ feature {
+ name: 'linker_subsystem_flag'
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: '/SUBSYSTEM:CONSOLE'
+ }
+ }
+ }
+
+ # The "user_link_flags" contains user-defined linkopts (from build rules)
+ # so it should be defined after features that declare user-overridable flags.
+ # For example the "linker_subsystem_flag" defines a default "/SUBSYSTEM" flag
+ # but we want to let the user override it, therefore "link_flag_subsystem" is
+ # defined earlier in the CROSSTOOL file than "user_link_flags".
+ feature {
+ name: 'user_link_flags'
+ flag_set {
+ expand_if_all_available: 'user_link_flags'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'user_link_flags'
+ flag: '%{user_link_flags}'
+ }
+ }
+ }
+ feature {
+ name: 'legacy_link_flags'
+ flag_set {
+ expand_if_all_available: 'legacy_link_flags'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'legacy_link_flags'
+ flag: '%{legacy_link_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'linker_param_file'
+ flag_set {
+ expand_if_all_available: 'linker_param_file'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ action: 'c++-link-static-library'
+ flag_group {
+ flag: '@%{linker_param_file}'
+ }
+ }
+ }
+
+ feature {
+ name: 'static_link_msvcrt'
+ }
+
+ feature {
+ name: 'static_link_msvcrt_no_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MT"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:libcmt.lib"
+ }
+ }
+ requires: { feature: 'fastbuild'}
+ requires: { feature: 'opt'}
+ }
+
+ feature {
+ name: 'dynamic_link_msvcrt_no_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MD"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:msvcrt.lib"
+ }
+ }
+ requires: { feature: 'fastbuild'}
+ requires: { feature: 'opt'}
+ }
+
+ feature {
+ name: 'static_link_msvcrt_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MTd"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:libcmtd.lib"
+ }
+ }
+ requires: { feature: 'dbg'}
+ }
+
+ feature {
+ name: 'dynamic_link_msvcrt_debug'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/MDd"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEFAULTLIB:msvcrtd.lib"
+ }
+ }
+ requires: { feature: 'dbg'}
+ }
+
+ feature {
+ name: 'dbg'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/Od"
+ flag: "/Z7"
+ flag: "/DDEBUG"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEBUG:FULL"
+ flag: "/INCREMENTAL:NO"
+ }
+ }
+ implies: 'generate_pdb_file'
+ }
+
+ feature {
+ name: 'fastbuild'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/Od"
+ flag: "/Z7"
+ flag: "/DDEBUG"
+ }
+ }
+ flag_set {
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEBUG:FASTLINK"
+ flag: "/INCREMENTAL:NO"
+ }
+ }
+ implies: 'generate_pdb_file'
+ }
+
+ feature {
+ name: 'opt'
+ flag_set {
+ action: 'c-compile'
+ action: 'c++-compile'
+ flag_group {
+ flag: "/O2"
+ flag: "/DNDEBUG"
+ }
+ }
+ }
+
+ feature {
+ name: 'user_compile_flags'
+ flag_set {
+ expand_if_all_available: 'user_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'user_compile_flags'
+ flag: '%{user_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'sysroot'
+ flag_set {
+ expand_if_all_available: 'sysroot'
+ action: 'assemble'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ iterate_over: 'sysroot'
+ flag: '--sysroot=%{sysroot}'
+ }
+ }
+ }
+
+ feature {
+ name: 'unfiltered_compile_flags'
+ flag_set {
+ expand_if_all_available: 'unfiltered_compile_flags'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ iterate_over: 'unfiltered_compile_flags'
+ flag: '%{unfiltered_compile_flags}'
+ }
+ }
+ }
+
+ feature {
+ name: 'compiler_output_flags'
+ flag_set {
+ action: 'assemble'
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_none_available: 'output_assembly_file'
+ expand_if_none_available: 'output_preprocess_file'
+ flag: '/Fo%{output_file}'
+ flag: '/Zi'
+ }
+ }
+ flag_set {
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_none_available: 'output_assembly_file'
+ expand_if_none_available: 'output_preprocess_file'
+ flag: '/Fo%{output_file}'
+ }
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_all_available: 'output_assembly_file'
+ flag: '/Fa%{output_file}'
+ }
+ flag_group {
+ expand_if_all_available: 'output_file'
+ expand_if_all_available: 'output_preprocess_file'
+ flag: '/P'
+ flag: '/Fi%{output_file}'
+ }
+ }
+ }
+
+ feature {
+ name: 'compiler_input_flags'
+ flag_set {
+ action: 'assemble'
+ action: 'preprocess-assemble'
+ action: 'c-compile'
+ action: 'c++-compile'
+ action: 'c++-header-parsing'
+ action: 'c++-module-compile'
+ action: 'c++-module-codegen'
+ flag_group {
+ expand_if_all_available: 'source_file'
+ flag: '/c'
+ flag: '%{source_file}'
+ }
+ }
+ }
+
+ feature {
+ name : 'def_file',
+ flag_set {
+ expand_if_all_available: 'def_file_path'
+ action: 'c++-link-executable'
+ action: 'c++-link-dynamic-library'
+ action: "c++-link-nodeps-dynamic-library"
+ flag_group {
+ flag: "/DEF:%{def_file_path}"
+ # We can specify a different DLL name in DEF file, /ignore:4070 suppresses
+ # the warning message about DLL name doesn't match the default one.
+ # See https://msdn.microsoft.com/en-us/library/sfkk2fz7.aspx
+ flag: "/ignore:4070"
+ }
+ }
+ }
+
+ feature {
+ name: 'windows_export_all_symbols'
+ }
+
+ feature {
+ name: 'no_windows_export_all_symbols'
+ }
+
+ linking_mode_flags { mode: DYNAMIC }
+}
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/clang/bin/crosstool_wrapper_driver_is_not_gcc b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/clang/bin/crosstool_wrapper_driver_is_not_gcc
new file mode 100755
index 0000000..7ae59e9
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/clang/bin/crosstool_wrapper_driver_is_not_gcc
@@ -0,0 +1,264 @@
+#!/usr/bin/env python
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Crosstool wrapper for compiling CUDA programs.
+
+SYNOPSIS:
+ crosstool_wrapper_is_not_gcc [options passed in by cc_library()
+ or cc_binary() rule]
+
+DESCRIPTION:
+ This script is expected to be called by the cc_library() or cc_binary() bazel
+ rules. When the option "-x cuda" is present in the list of arguments passed
+ to this script, it invokes the nvcc CUDA compiler. Most arguments are passed
+ as is as a string to --compiler-options of nvcc. When "-x cuda" is not
+ present, this wrapper invokes hybrid_driver_is_not_gcc with the input
+ arguments as is.
+
+NOTES:
+ Changes to the contents of this file must be propagated from
+ //third_party/gpus/crosstool/crosstool_wrapper_is_not_gcc to
+ //third_party/gpus/crosstool/v*/*/clang/bin/crosstool_wrapper_is_not_gcc
+"""
+
+from __future__ import print_function
+
+__author__ = 'keveman@google.com (Manjunath Kudlur)'
+
+from argparse import ArgumentParser
+import os
+import subprocess
+import re
+import sys
+import pipes
+
+# Template values set by cuda_autoconf.
+CPU_COMPILER = ('/usr/bin/gcc')
+GCC_HOST_COMPILER_PATH = ('/usr/bin/gcc')
+
+NVCC_PATH = '/usr/local/cuda-10.0/bin/nvcc'
+PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
+NVCC_VERSION = '10.0'
+
+def Log(s):
+ print('gpus/crosstool: {0}'.format(s))
+
+
+def GetOptionValue(argv, option):
+ """Extract the list of values for option from the argv list.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ option: The option whose value to extract, without the leading '-'.
+
+ Returns:
+ A list of values, either directly following the option,
+ (eg., -opt val1 val2) or values collected from multiple occurrences of
+ the option (eg., -opt val1 -opt val2).
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-' + option, nargs='*', action='append')
+ args, _ = parser.parse_known_args(argv)
+ if not args or not vars(args)[option]:
+ return []
+ else:
+ return sum(vars(args)[option], [])
+
+
+def GetHostCompilerOptions(argv):
+ """Collect the -isystem, -iquote, and --sysroot option values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ The string that can be used as the --compiler-options to nvcc.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-isystem', nargs='*', action='append')
+ parser.add_argument('-iquote', nargs='*', action='append')
+ parser.add_argument('--sysroot', nargs=1)
+ parser.add_argument('-g', nargs='*', action='append')
+ parser.add_argument('-fno-canonical-system-headers', action='store_true')
+
+ args, _ = parser.parse_known_args(argv)
+
+ opts = ''
+
+ if args.isystem:
+ opts += ' -isystem ' + ' -isystem '.join(sum(args.isystem, []))
+ if args.iquote:
+ opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, []))
+ if args.g:
+ opts += ' -g' + ' -g'.join(sum(args.g, []))
+ if args.fno_canonical_system_headers:
+ opts += ' -fno-canonical-system-headers'
+ if args.sysroot:
+ opts += ' --sysroot ' + args.sysroot[0]
+
+ return opts
+
+def _update_options(nvcc_options):
+ if NVCC_VERSION in ("7.0",):
+ return nvcc_options
+
+ update_options = { "relaxed-constexpr" : "expt-relaxed-constexpr" }
+ return [ update_options[opt] if opt in update_options else opt
+ for opt in nvcc_options ]
+
+def GetNvccOptions(argv):
+ """Collect the -nvcc_options values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ The string that can be passed directly to nvcc.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-nvcc_options', nargs='*', action='append')
+
+ args, _ = parser.parse_known_args(argv)
+
+ if args.nvcc_options:
+ options = _update_options(sum(args.nvcc_options, []))
+ return ' '.join(['--'+a for a in options])
+ return ''
+
+
+def InvokeNvcc(argv, log=False):
+ """Call nvcc with arguments assembled from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ log: True if logging is requested.
+
+ Returns:
+ The return value of calling os.system('nvcc ' + args)
+ """
+
+ host_compiler_options = GetHostCompilerOptions(argv)
+ nvcc_compiler_options = GetNvccOptions(argv)
+ opt_option = GetOptionValue(argv, 'O')
+ m_options = GetOptionValue(argv, 'm')
+ m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
+ include_options = GetOptionValue(argv, 'I')
+ out_file = GetOptionValue(argv, 'o')
+ depfiles = GetOptionValue(argv, 'MF')
+ defines = GetOptionValue(argv, 'D')
+ defines = ''.join([' -D' + define for define in defines])
+ undefines = GetOptionValue(argv, 'U')
+ undefines = ''.join([' -U' + define for define in undefines])
+ std_options = GetOptionValue(argv, 'std')
+ # currently only c++11 is supported by Cuda 7.0 std argument
+ nvcc_allowed_std_options = ["c++11"]
+ std_options = ''.join([' -std=' + define
+ for define in std_options if define in nvcc_allowed_std_options])
+
+ # The list of source files get passed after the -c option. I don't know of
+ # any other reliable way to just get the list of source files to be compiled.
+ src_files = GetOptionValue(argv, 'c')
+
+ # Pass -w through from host to nvcc, but don't do anything fancier with
+ # warnings-related flags, since they're not necessarily the same across
+ # compilers.
+ warning_options = ' -w' if '-w' in argv else ''
+
+ if len(src_files) == 0:
+ return 1
+ if len(out_file) != 1:
+ return 1
+
+ opt = (' -O2' if (len(opt_option) > 0 and int(opt_option[0]) > 0)
+ else ' -g -G')
+
+ includes = (' -I ' + ' -I '.join(include_options)
+ if len(include_options) > 0
+ else '')
+
+ # Unfortunately, there are other options that have -c prefix too.
+ # So allowing only those look like C/C++ files.
+ src_files = [f for f in src_files if
+ re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
+ srcs = ' '.join(src_files)
+ out = ' -o ' + out_file[0]
+
+ supported_cuda_compute_capabilities = [ "3.0" ]
+ nvccopts = '-D_FORCE_INLINES '
+ for capability in supported_cuda_compute_capabilities:
+ capability = capability.replace('.', '')
+ nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s,compute_%s\" ' % (
+ capability, capability, capability)
+ nvccopts += ' ' + nvcc_compiler_options
+ nvccopts += undefines
+ nvccopts += defines
+ nvccopts += std_options
+ nvccopts += m_options
+ nvccopts += warning_options
+
+ if depfiles:
+ # Generate the dependency file
+ depfile = depfiles[0]
+ cmd = (NVCC_PATH + ' ' + nvccopts +
+ ' --compiler-options "' + host_compiler_options + '"' +
+ ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH +
+ ' -I .' +
+ ' -x cu ' + opt + includes + ' ' + srcs + ' -M -o ' + depfile)
+ if log: Log(cmd)
+ exit_status = os.system(cmd)
+ if exit_status != 0:
+ return exit_status
+
+ cmd = (NVCC_PATH + ' ' + nvccopts +
+ ' --compiler-options "' + host_compiler_options + ' -fPIC"' +
+ ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH +
+ ' -I .' +
+ ' -x cu ' + opt + includes + ' -c ' + srcs + out)
+
+ # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'.
+ # Need to investigate and fix.
+ cmd = 'PATH=' + PREFIX_DIR + ':$PATH ' + cmd
+ if log: Log(cmd)
+ return os.system(cmd)
+
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument('-x', nargs=1)
+ parser.add_argument('--cuda_log', action='store_true')
+ args, leftover = parser.parse_known_args(sys.argv[1:])
+
+ if args.x and args.x[0] == 'cuda':
+ if args.cuda_log: Log('-x cuda')
+ leftover = [pipes.quote(s) for s in leftover]
+ if args.cuda_log: Log('using nvcc')
+ return InvokeNvcc(leftover, log=args.cuda_log)
+
+ # Strip our flags before passing through to the CPU compiler for files which
+ # are not -x cuda. We can't just pass 'leftover' because it also strips -x.
+ # We not only want to pass -x to the CPU compiler, but also keep it in its
+ # relative location in the argv list (the compiler is actually sensitive to
+ # this).
+ cpu_compiler_flags = [flag for flag in sys.argv[1:]
+ if not flag.startswith(('--cuda_log'))]
+
+ return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/windows/msvc_wrapper_for_nvcc.bat b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/windows/msvc_wrapper_for_nvcc.bat
new file mode 100755
index 0000000..e896e65
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/windows/msvc_wrapper_for_nvcc.bat
@@ -0,0 +1,20 @@
+:: Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+::
+:: Licensed under the Apache License, Version 2.0 (the "License");
+:: you may not use this file except in compliance with the License.
+:: You may obtain a copy of the License at
+::
+:: http://www.apache.org/licenses/LICENSE-2.0
+::
+:: Unless required by applicable law or agreed to in writing, software
+:: distributed under the License is distributed on an "AS IS" BASIS,
+:: WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+:: See the License for the specific language governing permissions and
+:: limitations under the License.
+:: =============================================================================
+
+:: Invoke msvc_wrapper_for_nvcc.py, which is located in the same directory.
+@echo OFF
+set arg0=%~0
+for %%F in ("%arg0%") do set DRIVER_BIN=%%~dpF
+"/usr/bin/python3" -B "%DRIVER_BIN%\msvc_wrapper_for_nvcc.py" %*
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/windows/msvc_wrapper_for_nvcc.py b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/windows/msvc_wrapper_for_nvcc.py
new file mode 100755
index 0000000..0048395
--- /dev/null
+++ b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/windows/msvc_wrapper_for_nvcc.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Crosstool wrapper for compiling CUDA programs with nvcc on Windows.
+
+DESCRIPTION:
+ This script is the Windows version of //third_party/gpus/crosstool/crosstool_wrapper_is_not_gcc
+"""
+
+from __future__ import print_function
+
+from argparse import ArgumentParser
+import os
+import subprocess
+import re
+import sys
+import pipes
+
+# Template values set by cuda_autoconf.
+CPU_COMPILER = ('/usr/bin/gcc')
+GCC_HOST_COMPILER_PATH = ('/usr/bin/gcc')
+
+NVCC_PATH = '/usr/local/cuda-10.0/bin/nvcc'
+NVCC_VERSION = '10.0'
+NVCC_TEMP_DIR = "C:\\Windows\\Temp\\nvcc_inter_files_tmp_dir"
+supported_cuda_compute_capabilities = [ "3.0" ]
+
+def Log(s):
+ print('gpus/crosstool: {0}'.format(s))
+
+
+def GetOptionValue(argv, option):
+ """Extract the list of values for option from options.
+
+ Args:
+ option: The option whose value to extract, without the leading '/'.
+
+ Returns:
+ 1. A list of values, either directly following the option,
+ (eg., /opt val1 val2) or values collected from multiple occurrences of
+ the option (eg., /opt val1 /opt val2).
+ 2. The leftover options.
+ """
+
+ parser = ArgumentParser(prefix_chars='/')
+ parser.add_argument('/' + option, nargs='*', action='append')
+ args, leftover = parser.parse_known_args(argv)
+ if args and vars(args)[option]:
+ return (sum(vars(args)[option], []), leftover)
+ return ([], leftover)
+
+def _update_options(nvcc_options):
+ if NVCC_VERSION in ("7.0",):
+ return nvcc_options
+
+ update_options = { "relaxed-constexpr" : "expt-relaxed-constexpr" }
+ return [ update_options[opt] if opt in update_options else opt
+ for opt in nvcc_options ]
+
+def GetNvccOptions(argv):
+ """Collect the -nvcc_options values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ 1. The string that can be passed directly to nvcc.
+ 2. The leftover options.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-nvcc_options', nargs='*', action='append')
+
+ args, leftover = parser.parse_known_args(argv)
+
+ if args.nvcc_options:
+ options = _update_options(sum(args.nvcc_options, []))
+ return (['--' + a for a in options], leftover)
+ return ([], leftover)
+
+
+def InvokeNvcc(argv, log=False):
+ """Call nvcc with arguments assembled from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ log: True if logging is requested.
+
+ Returns:
+ The return value of calling os.system('nvcc ' + args)
+ """
+
+ src_files = [f for f in argv if
+ re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
+ if len(src_files) == 0:
+ raise Error('No source files found for cuda compilation.')
+
+ out_file = [ f for f in argv if f.startswith('/Fo') ]
+ if len(out_file) != 1:
+ raise Error('Please sepecify exactly one output file for cuda compilation.')
+ out = ['-o', out_file[0][len('/Fo'):]]
+
+ nvcc_compiler_options, argv = GetNvccOptions(argv)
+
+ opt_option, argv = GetOptionValue(argv, 'O')
+ opt = ['-g', '-G']
+ if (len(opt_option) > 0 and opt_option[0] != 'd'):
+ opt = ['-O2']
+
+ include_options, argv = GetOptionValue(argv, 'I')
+ includes = ["-I " + include for include in include_options]
+
+ defines, argv = GetOptionValue(argv, 'D')
+ defines = ['-D' + define for define in defines]
+
+ undefines, argv = GetOptionValue(argv, 'U')
+ undefines = ['-U' + define for define in undefines]
+
+ # The rest of the unrecongized options should be passed to host compiler
+ host_compiler_options = [option for option in argv if option not in (src_files + out_file)]
+
+ m_options = ["-m64"]
+
+ nvccopts = ['-D_FORCE_INLINES']
+ for capability in supported_cuda_compute_capabilities:
+ capability = capability.replace('.', '')
+ nvccopts += [r'-gencode=arch=compute_%s,"code=sm_%s,compute_%s"' % (
+ capability, capability, capability)]
+ nvccopts += nvcc_compiler_options
+ nvccopts += undefines
+ nvccopts += defines
+ nvccopts += m_options
+ nvccopts += ['--compiler-options="' + " ".join(host_compiler_options) + '"']
+ nvccopts += ['-x', 'cu'] + opt + includes + out + ['-c'] + src_files
+ # If we don't specify --keep-dir, nvcc will generate intermediate files under TEMP
+ # Put them under NVCC_TEMP_DIR instead, then Bazel can ignore files under NVCC_TEMP_DIR during dependency check
+ # http://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver
+ # Different actions are sharing NVCC_TEMP_DIR, so we cannot remove it if the directory already exists.
+ if os.path.isfile(NVCC_TEMP_DIR):
+ os.remove(NVCC_TEMP_DIR)
+ if not os.path.exists(NVCC_TEMP_DIR):
+ os.makedirs(NVCC_TEMP_DIR)
+ nvccopts += ['--keep', '--keep-dir', NVCC_TEMP_DIR]
+ cmd = [NVCC_PATH] + nvccopts
+ if log:
+ Log(cmd)
+ proc = subprocess.Popen(cmd,
+ stdout=sys.stdout,
+ stderr=sys.stderr,
+ env=os.environ.copy(),
+ shell=True)
+ proc.wait()
+ return proc.returncode
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument('-x', nargs=1)
+ parser.add_argument('--cuda_log', action='store_true')
+ args, leftover = parser.parse_known_args(sys.argv[1:])
+
+ if args.x and args.x[0] == 'cuda':
+ if args.cuda_log: Log('-x cuda')
+ leftover = [pipes.quote(s) for s in leftover]
+ if args.cuda_log: Log('using nvcc')
+ return InvokeNvcc(leftover, log=args.cuda_log)
+
+ # Strip our flags before passing through to the CPU compiler for files which
+ # are not -x cuda. We can't just pass 'leftover' because it also strips -x.
+ # We not only want to pass -x to the CPU compiler, but also keep it in its
+ # relative location in the argv list (the compiler is actually sensitive to
+ # this).
+ cpu_compiler_flags = [flag for flag in sys.argv[1:]
+ if not flag.startswith(('--cuda_log'))
+ and not flag.startswith(('-nvcc_options'))]
+
+ return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/tools/bazel.rc b/tools/bazel.rc
index 8c2052e..1fdf51f 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -72,6 +72,7 @@
build:nohdfs --define=no_hdfs_support=true
build:nokafka --define=no_kafka_support=true
build:noignite --define=no_ignite_support=true
+build:nonccl --define=no_nccl_support=true
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true