Merge commit for internal changes
diff --git a/configure.py b/configure.py
index 05e3e80..cf16ef4 100644
--- a/configure.py
+++ b/configure.py
@@ -1122,12 +1122,16 @@
write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH',
computecpp_toolkit_path)
+
def set_trisycl_include_dir(environ_cp):
"""Set TRISYCL_INCLUDE_DIR."""
- ask_trisycl_include_dir = (
- 'Please specify the location of the triSYCL include directory. (Use '
- '--config=sycl_trisycl when building with Bazel) '
- '[Default is %s]: ') % _DEFAULT_TRISYCL_INCLUDE_DIR
+
+ ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
+ 'include directory. (Use --config=sycl_trisycl '
+ 'when building with Bazel) '
+ '[Default is %s]: '
+ ) % (_DEFAULT_TRISYCL_INCLUDE_DIR)
+
while True:
trisycl_include_dir = get_from_env_or_user_or_default(
environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir,
@@ -1201,46 +1205,10 @@
raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home)
-def set_mkl():
- write_to_bazelrc('build:mkl --define using_mkl=true')
- write_to_bazelrc('build:mkl -c opt')
- print(
- 'Add "--config=mkl" to your bazel command to build with MKL '
- 'support.\nPlease note that MKL on MacOS or windows is still not '
- 'supported.\nIf you would like to use a local MKL instead of '
- 'downloading, please set the environment variable \"TF_MKL_ROOT\" every '
- 'time before build.\n')
-
-
-def set_monolithic():
- # Add --config=monolithic to your bazel command to use a mostly-static
- # build and disable modular op registration support (this will revert to
- # loading TensorFlow with RTLD_GLOBAL in Python). By default (without
- # --config=monolithic), TensorFlow will build with a dependence on
- # //tensorflow:libtensorflow_framework.so.
- write_to_bazelrc('build:monolithic --define framework_shared_object=false')
- # For projects which use TensorFlow as part of a Bazel build process, putting
- # nothing in a bazelrc will default to a monolithic build. The following line
- # opts in to modular op registration support by default:
- write_to_bazelrc('build --define framework_shared_object=true')
-
-
-def create_android_bazelrc_configs():
- # Flags for --config=android
- write_to_bazelrc('build:android --crosstool_top=//external:android/crosstool')
- write_to_bazelrc(
- 'build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain')
- # Flags for --config=android_arm
- write_to_bazelrc('build:android_arm --config=android')
- write_to_bazelrc('build:android_arm --cpu=armeabi-v7a')
- # Flags for --config=android_arm64
- write_to_bazelrc('build:android_arm64 --config=android')
- write_to_bazelrc('build:android_arm64 --cpu=arm64-v8a')
-
-
def set_grpc_build_flags():
write_to_bazelrc('build --define grpc_no_ares=true')
+
def set_windows_build_flags():
if is_windows():
# The non-monolithic build is not supported yet
@@ -1251,6 +1219,11 @@
write_to_bazelrc('build --verbose_failures')
+def config_info_line(name, help_text):
+ """Helper function to print formatted help text for Bazel config options."""
+ print('\t--config=%-12s\t# %s' % (name, help_text))
+
+
def main():
# Make a copy of os.environ to be clear when functions and getting and setting
# environment variables.
@@ -1336,10 +1309,7 @@
set_grpc_build_flags()
set_cc_opt_flags(environ_cp)
- set_mkl()
- set_monolithic()
set_windows_build_flags()
- create_android_bazelrc_configs()
if workspace_has_any_android_rule():
print('The WORKSPACE file has at least one of ["android_sdk_repository", '
@@ -1357,6 +1327,11 @@
create_android_ndk_rule(environ_cp)
create_android_sdk_rule(environ_cp)
+ print('Preconfigured Bazel build configs. You can use any of the below by '
+ 'adding "--config=<>" to your build command. See tools/bazel.rc for '
+ 'more details.')
+ config_info_line('mkl', 'Build with MKL support.')
+ config_info_line('monolithic', 'Config for mostly static monolithic build.')
if __name__ == '__main__':
main()
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 53df884..74190cb 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -117,3 +117,9 @@
"//tensorflow/core:lib",
],
)
+
+filegroup(
+ name = "headers",
+ srcs = ["c_api.h"],
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 589afb9..5a5e5fe 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -98,7 +98,10 @@
void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
status->status = tensorflow::Status::OK();
- tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
+ {
+ tensorflow::mutex_lock ml(ctx->cache_mu);
+ tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
+ }
TF_Graph* graph = ctx->session->graph;
TF_DeleteSession(ctx->session, status);
TF_DeleteGraph(graph);
@@ -110,6 +113,11 @@
return TF_SessionListDevices(ctx->session, status);
}
+void TFE_ContextClearCaches(TFE_Context* ctx) {
+ tensorflow::mutex_lock ml(ctx->cache_mu);
+ tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
+}
+
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
@@ -489,8 +497,11 @@
std::vector<tensorflow::Tensor> outputs(1);
const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name());
- tensorflow::KernelAndDevice* kernel =
- tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
+ tensorflow::KernelAndDevice* kernel;
+ {
+ tensorflow::tf_shared_lock l(ctx->cache_mu);
+ kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
+ }
if (kernel == nullptr) {
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
@@ -506,6 +517,7 @@
delete kernel;
return;
}
+ tensorflow::mutex_lock ml(ctx->cache_mu);
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
}
std::vector<TFE_TensorHandle*> copied_tensors;
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 7caab43..9b0fd03 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -17,6 +17,8 @@
#define TENSORFLOW_C_EAGER_C_API_H_
// C API extensions to experiment with eager execution of kernels.
+// WARNING: Unlike tensorflow/c/c_api.h, the API here is not guaranteed to be
+// stable and can change without notice.
#include "tensorflow/c/c_api.h"
@@ -87,6 +89,10 @@
TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
TF_Status* status);
+// Clears the internal caches in the TFE context. Useful when reseeding random
+// ops.
+TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx);
+
// A handle to a tensor on a device.
//
// Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 11b7a51..55a04d4 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -58,9 +58,10 @@
// session->devices[i].
std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
+ tensorflow::mutex cache_mu;
std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
tensorflow::Fprint128Hasher>
- kernel_cache;
+ kernel_cache GUARDED_BY(cache_mu);
tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
return pflr->GetFLR(d->name());
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 78777f3..f7c6cd2 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -248,7 +248,9 @@
tags = ["optonly"],
deps = [
":xla_test",
+ "//tensorflow/contrib/signal:signal_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
"//tensorflow/python:spectral_ops",
diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py
index bdc38be..afb5fa4 100644
--- a/tensorflow/compiler/tests/fft_test.py
+++ b/tensorflow/compiler/tests/fft_test.py
@@ -21,8 +21,10 @@
import itertools
import numpy as np
+import scipy.signal as sps
from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.contrib.signal.python.ops import spectral_ops as signal
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import spectral_ops
@@ -76,6 +78,29 @@
value = sess.run(out, {ph: data})
self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL)
+ def testContribSignalSTFT(self):
+ ws = 512
+ hs = 128
+ dims = (ws * 20,)
+ shape = BATCH_DIMS + dims
+ data = np.arange(np.prod(shape)) / np.prod(dims)
+ np.random.seed(123)
+ np.random.shuffle(data)
+ data = np.reshape(data.astype(np.float32), shape)
+ window = sps.get_window("hann", ws)
+ expected = sps.stft(
+ data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2]
+ expected = np.swapaxes(expected, -1, -2)
+ expected *= window.sum() # scipy divides by window sum
+ with self.test_session() as sess:
+ with self.test_scope():
+ ph = array_ops.placeholder(
+ dtypes.as_dtype(data.dtype), shape=data.shape)
+ out = signal.stft(ph, ws, hs)
+
+ value = sess.run(out, {ph: data})
+ self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL)
+
def testFFT(self):
self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft,
spectral_ops.fft)
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index dd67a1d..f76e214 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -37,6 +37,8 @@
namespace {
+using xla::StatusOr;
+
const char* const kArgOp = "_Arg";
const char* const kRetValOp = "_Retval";
@@ -76,6 +78,20 @@
std::unordered_set<Node*> nodes;
};
+// Comparison function used for sorting nodes consistently, such that:
+// a) resource variables are last, and
+// b) within each group, nodes are sorted by name (for deterministic output).
+struct NodeCmp {
+ bool operator()(const Node* lhs, const Node* rhs) const {
+ bool lhs_is_resource =
+ lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false;
+ bool rhs_is_resource =
+ rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false;
+ return std::tie(lhs_is_resource, lhs->name()) <
+ std::tie(rhs_is_resource, rhs->name());
+ }
+};
+
// Returns a textual representation of the names of the nodes in the input.
template <typename T>
string NodesToString(const T& nodes) {
@@ -141,7 +157,7 @@
return Status::OK();
}
-xla::StatusOr<Node*> AddNode(const NodeDef& node_def, Graph* graph) {
+StatusOr<Node*> AddNode(const NodeDef& node_def, Graph* graph) {
Status status;
Node* inserted_node = graph->AddNode(node_def, &status);
if (!status.ok()) {
@@ -150,7 +166,7 @@
return inserted_node;
}
-xla::StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
+StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
NodeDef arg_def;
NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp);
builder.Attr("T", type);
@@ -159,7 +175,7 @@
return AddNode(arg_def, graph);
}
-xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
+StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
NodeDef ret_def;
ret_def.set_op(kRetValOp);
ret_def.set_name(strings::StrCat(kRetValOp, index));
@@ -310,16 +326,9 @@
}
frame->args = std::move(args);
- // Order the arguments so that:
- // a) resource variables are last, and
- // b) sort lexicographically by name (for deterministic output).
- std::sort(frame->args.begin(), frame->args.end(),
- [](const Arg& a, const Arg& b) {
- bool a_is_resource = (a.enter->input_type(0) == DT_RESOURCE);
- bool b_is_resource = (b.enter->input_type(0) == DT_RESOURCE);
- return std::tie(a_is_resource, a.enter->name()) <
- std::tie(b_is_resource, b.enter->name());
- });
+ std::sort(
+ frame->args.begin(), frame->args.end(),
+ [](const Arg& a, const Arg& b) { return NodeCmp()(a.enter, b.enter); });
if (frame->loop_cond == nullptr) {
return errors::InvalidArgument("Loop ", frame->name,
@@ -542,18 +551,6 @@
// Returns a textual representation of the Branch b.
static string Branch_Name(FunctionalizeCond::Branch b);
- // Comparison function used for sorting nodes consistently.
- struct CondCmp {
- bool operator()(const Node* lhs, const Node* rhs) const {
- bool lhs_is_resource =
- lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false;
- bool rhs_is_resource =
- rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false;
- return std::tie(lhs_is_resource, lhs->name()) <
- std::tie(rhs_is_resource, rhs->name());
- }
- };
-
// Functionalize all the switch-merge nodes of a loop-free graph into XlaIf
// nodes. That is, attempt to transform every remaining switch and merge nodes
// in the graph into XlaIf nodes.
@@ -571,6 +568,13 @@
int count;
};
+ struct PredicateSwitches {
+ explicit PredicateSwitches(Node* predicate) : predicate(predicate) {}
+
+ Node* predicate;
+ std::vector<Node*> switches;
+ };
+
FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library)
: library_(library), graph_(graph) {}
@@ -579,18 +583,24 @@
// extract into XlaIf nodes.
Status FunctionalizeInternal();
- // Converts a Merge node to a XlaIf. This encapsulates the process of
- // extracting the bodies needed for the then and else branch, creates a XlaIf
- // node, removing the nodes of the branches from the graph and replacing the
- // merge node with a XlaIf.
- Status ConvertCorrespondingMergeToXlaIf(
- const std::vector<Node*>& switch_nodes,
- const std::vector<Node*>& merge_nodes, Node* predicate);
+ // Determines the branch_map (mapping from node to branch of cond) and
+ // frontier (the nodes where the cond ends).
+ StatusOr<std::pair<std::unordered_map<Node*, ForwardFlowNode>,
+ std::unordered_set<Node*>>>
+ DetermineBranchMapAndFrontier(const std::vector<Node*>& switches);
+
+ // Returns the XlaIf node created from the subgraph of merge and switch
+ // nodes. This encapsulates extracting the bodies needed for the then and
+ // else branches, creating an XlaIf node, removing the nodes of the branches
+ // from the graph, and replacing the merge node with the XlaIf.
+ StatusOr<Node*> ConvertToXlaIf(const std::vector<Node*>& switch_nodes,
+ const std::vector<Node*>& merge_nodes,
+ Node* predicate);
// Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with.
- xla::StatusOr<Node*> BuildAndAddXlaIfOp(
- const std::vector<Node*>& switch_nodes,
- const std::vector<Node*>& merge_nodes, Node* predicate);
+ StatusOr<Node*> BuildAndAddXlaIfOp(const std::vector<Node*>& switch_nodes,
+ const std::vector<Node*>& merge_nodes,
+ Node* predicate);
// Extracts a function body corresponding to the given input edge of the merge
// node.
@@ -605,18 +615,26 @@
// Adds all output edges from the `if_node`.
Status AddOutputEdges(const std::vector<Node*>& outputs, Node* if_node);
- // Returns the switches of graph_ in postorder. Dead switch nodes are skipped
- // and removed from the graph.
- std::vector<Node*> DetermineSwitchOrder();
+ // Returns the switches of graph_ (along with grouping predicates) in
+ // postorder. Dead switch nodes are skipped and removed from the graph.
+ std::vector<PredicateSwitches> DeterminePredicateSwitchOrder();
// Update the state for destination based on the state of source and the node
// being updated.
Status Join(const ForwardFlowNode& src_state, const Node* dst,
ForwardFlowNode* dst_state);
- // Validates that the branch_map and frontier of nodes for the conditional
+ // Ensures that all nodes in the branch_map are dominated by the switch
+ // nodes. Returns the nodes that are not dominated by the switches but are
+ // a control dependency of a node in the cond, and removes those control
+ // dependencies from the graph.
+ StatusOr<std::vector<Node*>> EnsureDominanceAndReturnNonDominatedControlNodes(
+ const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
+ const std::vector<Node*>& switches);
+
+ // Validates that the frontier of nodes for the conditional
// section are as expected.
- Status ValidBranchMapAndFrontier(
+ Status ValidateFrontier(
const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
const std::unordered_set<Node*>& frontier);
@@ -645,23 +663,11 @@
return branch_name[b];
}
-Status FunctionalizeCond::ValidBranchMapAndFrontier(
+Status FunctionalizeCond::ValidateFrontier(
const std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>&
branch_map,
const std::unordered_set<Node*>& frontier) {
std::unordered_set<const Node*> pending[kNumBranchTypes];
- for (const auto& kv : branch_map) {
- if (kv.second.count != kv.first->in_edges().size()) {
- return errors::FailedPrecondition("Value ", kv.first->DebugString(),
- " not dominated by switch nodes.");
- }
- if (VLOG_IS_ON(1)) {
- // Append attribute to the graph if running with logging to make the
- // changes clearer in the visualization.
- kv.first->AddAttr("_XlaFunctionalizeBranch",
- Branch_Name(kv.second.branch));
- }
- }
for (Node* n : frontier) {
pending[branch_map.at(n).branch].insert(n);
}
@@ -681,6 +687,10 @@
") in both Else and Then branch should be in Both.");
}
}
+ if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
+ pending[kElseBranch].empty()) {
+ return errors::Internal("Unexpected empty frontier for switch nodes");
+ }
return Status::OK();
}
@@ -707,7 +717,8 @@
return Status::OK();
}
-std::vector<Node*> FunctionalizeCond::DetermineSwitchOrder() {
+std::vector<FunctionalizeCond::PredicateSwitches>
+FunctionalizeCond::DeterminePredicateSwitchOrder() {
std::vector<Node*> dead_switches;
std::vector<Node*> switch_order;
DFS(*graph_, nullptr, [this, &dead_switches, &switch_order](Node* n) {
@@ -725,25 +736,12 @@
graph_->RemoveNode(n);
}
- return switch_order;
-}
-
-Status FunctionalizeCond::FunctionalizeInternal() {
- std::vector<Node*> switch_order = DetermineSwitchOrder();
- // If there are no switch nodes, then terminate.
+ std::vector<PredicateSwitches> predicate_switch_order;
if (switch_order.empty()) {
- return Status::OK();
+ return predicate_switch_order;
}
- struct PredicateSwitches {
- explicit PredicateSwitches(Node* predicate) : predicate(predicate) {}
-
- Node* predicate;
- std::vector<Node*> switches;
- };
-
// Merge Switch nodes with common predicate.
- std::vector<PredicateSwitches> predicate_switch_order;
std::unordered_map<Node*, int> predicate_index;
// The nodes in switch_order are in reverse topological order, but the
// clustered switches need not be (i.e., when considered as a cluster one
@@ -759,71 +757,145 @@
}
predicate_switch_order[predicate_index[pred]].switches.push_back(*it);
}
+ return predicate_switch_order;
+}
+
+StatusOr<std::vector<Node*>>
+FunctionalizeCond::EnsureDominanceAndReturnNonDominatedControlNodes(
+ const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
+ const std::vector<Node*>& switches) {
+ std::vector<Node*> old_control_nodes;
+ for (const auto& kv : branch_map) {
+ if (kv.second.count != kv.first->in_edges().size()) {
+ std::vector<const Edge*> delete_edges;
+ for (const Edge* in : kv.first->in_edges()) {
+ auto it = branch_map.find(in->src());
+ if (it == branch_map.end()) {
+ if (in->IsControlEdge()) {
+ old_control_nodes.push_back(in->src());
+ delete_edges.push_back(in);
+ } else {
+ if (IsSwitch(in->src())) {
+ if (std::find(switches.begin(), switches.end(), in->src()) ==
+ switches.end()) {
+ return errors::Internal(
+ "Unexpected switch node found during flow forward.");
+ }
+ continue;
+ }
+ return errors::InvalidArgument(
+ "Value ", kv.first->name(), "'s input, ", in->src()->name(),
+ ", is not dominated by switch nodes ", NodesToString(switches));
+ }
+ }
+ }
+ // Remove control edges from nodes that are not dominated by the switch
+ // nodes. New control dependencies will be added between these nodes and
+ // the XlaIf node inserted.
+ for (const Edge* e : delete_edges) {
+ graph_->RemoveEdge(e);
+ }
+ }
+ }
+ return old_control_nodes;
+}
+
+StatusOr<
+ std::pair<std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>,
+ std::unordered_set<Node*>>>
+FunctionalizeCond::DetermineBranchMapAndFrontier(
+ const std::vector<Node*>& switches) {
+ std::unordered_map<Node*, ForwardFlowNode> branch_map;
+ std::unordered_set<Node*> frontier;
+ std::vector<Node*> stack = switches;
+ std::vector<bool> visited(graph_->num_node_ids(), false);
+ while (!stack.empty()) {
+ Node* n = stack.back();
+ stack.pop_back();
+
+ if (visited[n->id()]) {
+ continue;
+ }
+ visited[n->id()] = true;
+
+ // Propagate branch state along each edge of a switch node.
+ bool sink_only = true;
+ for (const Edge* e : n->out_edges()) {
+ Node* out = e->dst();
+ if (!out->IsOp()) {
+ continue;
+ }
+ sink_only = false;
+ // Propagate branch information.
+ ForwardFlowNode& ffn = branch_map[out];
+ if (IsSwitch(n)) {
+ int index = e->IsControlEdge() ? Branch::kNeither : e->src_output();
+ TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn));
+ } else {
+ TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn));
+ }
+ if (IsMerge(out)) {
+ if (out->in_edges().size() == ffn.count) {
+ frontier.insert(out);
+ }
+ } else if (!visited[out->id()]) {
+ stack.push_back(out);
+ }
+ }
+ if (sink_only) {
+ if (!IsIdentity(n)) {
+ VLOG(1) << "Feeding into sink: " << n->DebugString();
+ }
+ }
+ }
+
+ if (VLOG_IS_ON(2)) {
+ for (const auto& kv : branch_map) {
+ // Append attribute to the graph if running with logging to make the
+ // changes clearer in the visualization.
+ kv.first->AddAttr("_XlaFunctionalizeBranch",
+ Branch_Name(kv.second.branch));
+ }
+ }
+ return std::make_pair(std::move(branch_map), std::move(frontier));
+}
+
+Status FunctionalizeCond::FunctionalizeInternal() {
+ std::vector<PredicateSwitches> predicate_switch_order =
+ DeterminePredicateSwitchOrder();
// Iterate from innermost set of clustered switches to outermost, replacing
// matching switch->merge subgraphs with single XlaIf nodes.
for (auto it = predicate_switch_order.rbegin();
it != predicate_switch_order.rend(); ++it) {
auto& ps = *it;
- VLOG(3) << "Flow down from: " << ps.predicate->name() << " -> "
- << NodesToString(ps.switches);
+ VLOG(3) << "Flow down from: " << NodesToString(ps.switches) << " ("
+ << ps.predicate->name() << ")";
std::unordered_map<Node*, ForwardFlowNode> branch_map;
std::unordered_set<Node*> frontier;
+ TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier),
+ DetermineBranchMapAndFrontier(ps.switches));
- std::vector<Node*> stack = ps.switches;
- std::vector<bool> visited(graph_->num_node_ids(), false);
- while (!stack.empty()) {
- Node* n = stack.back();
- stack.pop_back();
-
- if (visited[n->id()]) {
- continue;
- }
- visited[n->id()] = true;
-
- // Propagate branch state along each edge of a switch node.
- bool sink_only = true;
- for (const Edge* e : n->out_edges()) {
- Node* out = e->dst();
- if (!out->IsOp()) {
- continue;
- }
- sink_only = false;
- // Propagate branch information.
- ForwardFlowNode& ffn = branch_map[out];
- if (IsSwitch(n)) {
- int index = e->IsControlEdge() ? Branch::kNeither : e->src_output();
- TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn));
- } else {
- TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn));
- }
- if (IsMerge(out)) {
- if (out->in_edges().size() == ffn.count) {
- frontier.insert(out);
- }
- } else if (!visited[out->id()] && ffn.count == out->in_edges().size()) {
- // If all predecessors are dominated by the switch nodes, then add
- // the output to the stack.
- stack.push_back(out);
- }
- }
- if (sink_only) {
- if (!IsIdentity(n)) {
- VLOG(1) << "Feeding into sink: " << n->DebugString();
- }
- }
- }
-
- TF_RETURN_IF_ERROR(ValidBranchMapAndFrontier(branch_map, frontier));
VLOG(2) << "FunctionalizeControlFlow (before XlaIf conversion): "
<< dump_graph::DumpGraphToFile("functionalize_bc", *graph_);
+ TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier));
+
std::vector<Node*> switch_nodes(ps.switches);
- std::sort(switch_nodes.begin(), switch_nodes.end(), CondCmp());
+ std::sort(switch_nodes.begin(), switch_nodes.end(), NodeCmp());
std::vector<Node*> merge_nodes(frontier.begin(), frontier.end());
- std::sort(merge_nodes.begin(), merge_nodes.end(), CondCmp());
- TF_RETURN_IF_ERROR(ConvertCorrespondingMergeToXlaIf(
- switch_nodes, merge_nodes, ps.predicate));
+ std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp());
+ TF_ASSIGN_OR_RETURN(std::vector<Node*> old_control_nodes,
+ EnsureDominanceAndReturnNonDominatedControlNodes(
+ branch_map, ps.switches));
+
+ TF_ASSIGN_OR_RETURN(
+ Node * if_node,
+ ConvertToXlaIf(switch_nodes, merge_nodes, ps.predicate));
+ for (Node* old : old_control_nodes) {
+ graph_->AddControlEdge(old, if_node);
+ }
+
for (auto& del_kv : branch_map) {
graph_->RemoveNode(del_kv.first);
}
@@ -836,7 +908,7 @@
return Status::OK();
}
-xla::StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
+StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
const std::vector<Node*>& switch_nodes,
const std::vector<Node*>& merge_nodes, Node* predicate) {
VLOG(2) << "Build if op for " << NodesToString(merge_nodes) << " with input "
@@ -917,7 +989,7 @@
}
std::vector<Node*> stack;
- stack.reserve(switch_nodes.size());
+ stack.reserve(merge_nodes.size());
for (int j = 0; j < merge_nodes.size(); ++j) {
Node* node = merge_nodes[j];
TF_ASSIGN_OR_RETURN(node_map.at(node->id()),
@@ -988,10 +1060,10 @@
return Status::OK();
}
-Status FunctionalizeCond::ConvertCorrespondingMergeToXlaIf(
+StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
const std::vector<Node*>& switch_nodes,
const std::vector<Node*>& merge_nodes, Node* predicate) {
- VLOG(1) << "ConvertMergeToXlaIf for " << NodesToString(switch_nodes) << " -> "
+ VLOG(1) << "ConvertToXlaIf for " << NodesToString(switch_nodes) << " -> "
<< NodesToString(merge_nodes);
// Extract bodies and builds a If operator.
@@ -1000,7 +1072,7 @@
TF_RETURN_IF_ERROR(AddInputEdges(switch_nodes, predicate, if_node));
TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node));
- return Status::OK();
+ return if_node;
}
Status FunctionalizeCond::Functionalize(Graph* graph,
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index f373684..523169f 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -318,4 +318,8 @@
return std::move(literal);
}
+StatusOr<int> LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) {
+ return local_service_->ReplicaNumberToDeviceOrdinal(replica_number);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 3ca0d2e..19fd14f 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -176,6 +176,13 @@
StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocal(
const Shape& shape, int device_ordinal);
+ // Returns the device ordinal that corresponds to the given replica number.
+ //
+ // This returns an error if there is not a one-to-one correspondence of
+ // replicas to device ordinals, but is useful as a short term mechanism for
+ // the "easy" case where a single replica is a single device.
+ StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);
+
// Returns the platform that the underlying service targets.
perftools::gputools::Platform* platform() const;
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 03e403b..2b87cb0 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -21,9 +21,57 @@
namespace swig {
-void TransferToInfeedLocal(const Literal& literal) {
- LocalClient* client = ClientLibrary::LocalClientOrDie();
- TF_CHECK_OK(client->TransferToInfeedLocal(literal, /*device_ordinal=*/0));
+// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of
+// device handles instead of needing to set the number of replicas at XLA
+// service initialization time.
+tensorflow::mutex g_replica_count_mutex(tensorflow::LINKER_INITIALIZED);
+int g_replica_count = 1;
+bool g_local_client_created = false;
+
+Status InitializeReplicaCount(int replica_count) {
+ if (replica_count < 1) {
+ return InvalidArgument("Replica count must be >= 1; got %d.",
+ replica_count);
+ }
+ tensorflow::mutex_lock lock(g_replica_count_mutex);
+ if (g_local_client_created) {
+ return FailedPrecondition(
+ "Attempted to set the replica count to %d, but a local XLA service was "
+ "previously created with a replica count of %d.",
+ replica_count, g_replica_count);
+ }
+ g_replica_count = replica_count;
+ return Status::OK();
+}
+
+int GetReplicaCount() {
+ tensorflow::mutex_lock lock(g_replica_count_mutex);
+ return g_replica_count;
+}
+
+LocalClient* GetOrCreateLocalClient() {
+ LocalClientOptions options;
+ {
+ tensorflow::mutex_lock lock(g_replica_count_mutex);
+ options.set_number_of_replicas(g_replica_count);
+ g_local_client_created = true;
+ }
+ return ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie();
+}
+
+Status TransferToInfeedLocal(const Literal& literal) {
+ VLOG(1) << "Infeeding literal without replica number.";
+ LocalClient* client = GetOrCreateLocalClient();
+ return client->TransferToInfeedLocal(literal, /*device_ordinal=*/0);
+}
+
+Status TransferToInfeedLocalReplica(const Literal& literal,
+ int replica_number) {
+ VLOG(1) << "Infeeding literal to replica number: " << replica_number;
+ LocalClient* client = GetOrCreateLocalClient();
+ TF_ASSIGN_OR_RETURN(int device_ordinal,
+ client->ReplicaNumberToDeviceOrdinal(replica_number));
+ return client->TransferToInfeedLocal(literal, device_ordinal);
}
LocalShapedBuffer::LocalShapedBuffer(
@@ -37,7 +85,7 @@
/* static */
LocalShapedBuffer* LocalShapedBuffer::FromLiteral(const Literal& argument) {
- LocalClient* client = ClientLibrary::LocalClientOrDie();
+ LocalClient* client = GetOrCreateLocalClient();
std::unique_ptr<ScopedShapedBuffer> buf =
client
->LiteralToShapedBuffer(argument,
@@ -48,7 +96,7 @@
}
std::unique_ptr<Literal> LocalShapedBuffer::ToLiteral() const {
- LocalClient* client = ClientLibrary::LocalClientOrDie();
+ LocalClient* client = GetOrCreateLocalClient();
return client->ShapedBufferToLiteral(*shaped_buffer()).ConsumeValueOrDie();
}
@@ -56,43 +104,95 @@
std::unique_ptr<LocalExecutable> executable)
: executable_(std::move(executable)) {}
-std::unique_ptr<Literal> CompiledLocalComputation::Execute(
+StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
const std::vector<Literal>& arguments) {
- LocalClient* client = ClientLibrary::LocalClientOrDie();
+ LocalClient* client = GetOrCreateLocalClient();
- // Transfer arguments in
- std::vector<std::unique_ptr<ScopedShapedBuffer>> scoped_buffers;
- scoped_buffers.reserve(arguments.size());
- for (const Literal& argument : arguments) {
- scoped_buffers.push_back(
- client
- ->LiteralToShapedBuffer(argument,
- /*device_ordinal=*/0,
- client->backend().memory_allocator())
- .ConsumeValueOrDie());
+ // Each replica populates a StatusOr result, but only replica zero actually
+ // retrieves its literal value.
+ std::vector<StatusOr<std::unique_ptr<Literal>>> results(GetReplicaCount());
+ {
+ tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
+ GetReplicaCount());
+
+ for (int replica = 0; replica < GetReplicaCount(); ++replica) {
+ pool.Schedule([this, client, replica, &arguments, &results] {
+ StatusOr<int> device_ordinal_status =
+ client->ReplicaNumberToDeviceOrdinal(replica);
+ if (!device_ordinal_status.ok()) {
+ results[replica] = device_ordinal_status.status();
+ return;
+ }
+ const int device_ordinal = device_ordinal_status.ValueOrDie();
+ VLOG(3) << "Replica " << replica
+ << " mapped to device ordinal for execution: "
+ << device_ordinal;
+ // Transfer arguments in
+ std::vector<std::unique_ptr<ScopedShapedBuffer>> scoped_buffers;
+ scoped_buffers.reserve(arguments.size());
+ for (const Literal& argument : arguments) {
+ StatusOr<std::unique_ptr<ScopedShapedBuffer>> pushed =
+ client->LiteralToShapedBuffer(
+ argument, device_ordinal,
+ client->backend().memory_allocator());
+ if (!pushed.ok()) {
+ results[replica] = pushed.status();
+ return;
+ }
+ scoped_buffers.push_back(std::move(pushed).ValueOrDie());
+ }
+
+ // Execute
+ std::vector<const ShapedBuffer*> argument_buffers;
+ argument_buffers.reserve(scoped_buffers.size());
+ for (auto& buffer : scoped_buffers) {
+ argument_buffers.push_back(buffer.get());
+ }
+
+ DeviceAssignment device_assignment =
+ client->backend()
+ .computation_placer()
+ ->AssignDevices(GetReplicaCount(), /*computation_count=*/1)
+ .ConsumeValueOrDie();
+
+ ExecutableRunOptions options;
+ options.set_device_ordinal(device_ordinal);
+ options.set_allocator(client->backend().memory_allocator());
+ options.set_inter_op_thread_pool(
+ client->backend().inter_op_thread_pool());
+ options.set_intra_op_thread_pool(
+ client->backend().eigen_intra_op_thread_pool_device());
+ options.set_device_assignment(&device_assignment);
+ StatusOr<std::unique_ptr<ScopedShapedBuffer>> result_buffer_status =
+ executable_->Run(argument_buffers, options);
+ if (!result_buffer_status.ok()) {
+ results[replica] = result_buffer_status.status();
+ return;
+ }
+
+ // Transfer result out
+ results[replica] =
+ client->ShapedBufferToLiteral(*result_buffer_status.ValueOrDie());
+ });
+ }
}
- // Execute
- std::vector<const ShapedBuffer*> argument_buffers;
- argument_buffers.reserve(scoped_buffers.size());
- for (auto& buffer : scoped_buffers) {
- argument_buffers.push_back(buffer.get());
+ for (int replica = 0; replica < GetReplicaCount(); ++replica) {
+ const auto& statusor = results[replica];
+ if (!statusor.ok()) {
+ return InternalError(
+ "Failed running replica %d (other replicas may have failed as well): "
+ "%s.",
+ replica, statusor.status().ToString().c_str());
+ }
}
- ExecutableRunOptions options;
- options.set_allocator(client->backend().memory_allocator());
- options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool());
- options.set_intra_op_thread_pool(
- client->backend().eigen_intra_op_thread_pool_device());
- std::unique_ptr<ScopedShapedBuffer> result_buffer =
- executable_->Run(argument_buffers, options).ConsumeValueOrDie();
- // Transfer result out
- return client->ShapedBufferToLiteral(*result_buffer).ConsumeValueOrDie();
+ return std::move(results[0]);
}
LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers(
tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles) {
- LocalClient* client = ClientLibrary::LocalClientOrDie();
+ LocalClient* client = GetOrCreateLocalClient();
std::vector<const ShapedBuffer*> argument_buffers;
argument_buffers.reserve(argument_handles.size());
@@ -123,7 +223,7 @@
argument_shape_pointers.push_back(&argument_shape);
}
- LocalClient* client = ClientLibrary::LocalClientOrDie();
+ LocalClient* client = GetOrCreateLocalClient();
ExecutableBuildOptions options;
TF_ASSIGN_OR_RETURN(
auto local_executable,
@@ -136,7 +236,7 @@
}
LocalComputationBuilder::LocalComputationBuilder(const string& computation_name)
- : builder_(ClientLibrary::LocalClientOrDie(), computation_name) {}
+ : builder_(GetOrCreateLocalClient(), computation_name) {}
StatusOr<LocalComputation*> LocalComputationBuilder::Build() {
TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build());
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 8da4c99..1104d7f 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -27,11 +27,25 @@
namespace swig {
-// Wraps the local client's infeed-transfer function, aborting on error.
+// Initializes the number of replicas that XLA will be initialized with (when
+// first obtaining a handle to the local XLA service). If this is called after
+// the handle to the local XLA service has been established, then an error is
+// returned.
+Status InitializeReplicaCount(int replica_count);
+
+// Returns the replica count that is currently set, regardless of whether the
+// local XLA service has been instantiated yet or not.
+int GetReplicaCount();
+
+// Wraps the local client's infeed-transfer function.
//
-// TODO(leary) ideally we could return a value that would permit an appropriate
-// Python exception to be raised.
-void TransferToInfeedLocal(const Literal& literal);
+// The default device ordinal (0) is used.
+Status TransferToInfeedLocal(const Literal& literal);
+
+// Transfers the given literal to the infeed of the given replica.
+//
+// The replica number is resolved to an appropriate device ordinal.
+Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number);
// Wraps a ScopedShapedBuffer produced by copying a literal "to
// device," i.e. copying a literal to a scoped buffer via the local
@@ -56,7 +70,8 @@
class CompiledLocalComputation {
public:
CompiledLocalComputation(std::unique_ptr<LocalExecutable> executable);
- std::unique_ptr<Literal> Execute(const std::vector<Literal>& arguments);
+ StatusOr<std::unique_ptr<Literal> > Execute(
+ const std::vector<Literal>& arguments);
LocalShapedBuffer* ExecuteWithShapedBuffers(
tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles);
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 40089b8..8b4779a 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -188,6 +188,15 @@
}
}
+%typemap(out) Status {
+ if (!$1.ok()) {
+ PyErr_SetString(PyExc_RuntimeError, $1.ToString().c_str());
+ return NULL;
+ }
+ Py_INCREF(Py_None);  // $result is handed to Python as a new reference.
+ $result = Py_None;
+}
+
// ArraySlice<int64>
%typemap(in) tensorflow::gtl::ArraySlice<int64>
@@ -286,6 +295,14 @@
$result = numpy::PyObjectFromXlaLiteral(*$1);
}
+%typemap(out) StatusOr< std::unique_ptr<Literal> > {
+ if (!$1.ok()) {
+ PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
+ return NULL;
+ }
+ $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
+}
+
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
@@ -540,7 +557,10 @@
%ignoreall
%unignore xla;
%unignore xla::swig;
+%unignore xla::swig::InitializeReplicaCount;
+%unignore xla::swig::GetReplicaCount;
%unignore xla::swig::TransferToInfeedLocal;
+%unignore xla::swig::TransferToInfeedLocalReplica;
%unignore xla::swig::LocalShapedBuffer;
%unignore xla::swig::LocalShapedBuffer::FromLiteral;
%unignore xla::swig::LocalShapedBuffer::ToLiteral;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 96e2401..fead7d6 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -209,20 +209,24 @@
return np.require(value, requirements=['C', 'A'])
-def transfer_to_infeed(value):
+def transfer_to_infeed(value, replica_number=None):
"""Transfers the given value into the XLA infeed queue.
XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with
a totally ordered stream of values. This is dequeued from XLA computations via
the Infeed() operation.
- TODO(leary): this currently implicitly enqueues to device ordinal 0.
-
Args:
value: the value that the caller would like to enqueue into the XLA infeed
queue
+ replica_number: the replica number to infeed the value to -- if not
+ provided, then the default replica (trivially replica 0) is used.
"""
- c_api.TransferToInfeedLocal(require_numpy_array_layout(value))
+ if replica_number is None:
+ c_api.TransferToInfeedLocal(require_numpy_array_layout(value))
+ else:
+ c_api.TransferToInfeedLocalReplica(
+ require_numpy_array_layout(value), replica_number)
class LocalComputation(object):
@@ -832,3 +836,25 @@
_forward_methods_to_local_builder()
+
+
+def initialize_replica_count(replica_count):
+ """Initializes the desired replica count to use on XLA service init.
+
+ Args:
+ replica_count: number of replicas that are desired for set up during XLA
+ initialization.
+
+ Raises:
+ A runtime exception if the XLA service has already been initialized.
+ """
+ c_api.InitializeReplicaCount(replica_count)
+
+
+def get_replica_count():
+ """Returns the current replica count used for the XLA service.
+
+ Note: this will return a value whether the XLA service has been initialized
+ yet or not.
+ """
+ return c_api.GetReplicaCount()
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 173f0e2..ddd75bb 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -467,6 +467,27 @@
return HandleSign<ReturnT>(sign);
}
+ template <typename NativeT, typename std::enable_if<std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
+ Status HandleAtan2(HloInstruction* atan2) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2],
+ ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem,
+ ElementwiseT rhs_elem) {
+ return std::atan2(lhs_elem, rhs_elem);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT, typename std::enable_if<!std::is_floating_point<
+ NativeT>::value>::type* = nullptr>
+ Status HandleAtan2(HloInstruction* atan2) {
+ return InvalidArgument("Unsupported type for Atan2");
+ }
+
+ Status HandleAtan2(HloInstruction* atan2) override {
+ return HandleAtan2<ElementwiseT>(atan2);
+ }
+
Status HandleTanh(HloInstruction* tanh) override {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh],
ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index c5cab92..ff3dbdd 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -25,6 +25,7 @@
#include <iosfwd>
#include <list>
#include <memory>
+#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
@@ -1451,6 +1452,10 @@
using ConstHloInstructionMap =
std::map<const HloInstruction*, ValueT, HloPtrComparator>;
+using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
+using ConstHloInstructionSet =
+ std::set<const HloInstruction*, HloPtrComparator>;
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 7b3a8ce..204a8bf 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -169,8 +169,16 @@
std::unique_ptr<ScopedShapedBuffer> scoped_result,
ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator()));
- return backend().transfer_manager()->TransferLiteralFromDevice(
+ auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice(
stream.parent(), *scoped_result);
+ if (result_literal.ok()) {
+ VLOG(4) << "Executed binary and got result: "
+ << result_literal.ValueOrDie()->ToString();
+ } else {
+ VLOG(4) << "Executed binary and got status: "
+ << result_literal.status().ToString();
+ }
+ return result_literal;
}
Backend& HloRunner::backend() {
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 4071b94..d5715aa 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -122,4 +122,10 @@
execute_backend_.get(), executor);
}
+StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
+ return backend().computation_placer()->DeviceId(
+ replica_number, /*computation=*/0, options_.number_of_replicas(),
+ /*computation_count=*/1);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
index 52c4346..acbc726 100644
--- a/tensorflow/compiler/xla/service/local_service.h
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -47,6 +47,13 @@
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
const Shape* result_layout, int device_ordinal);
+ // Returns the device ordinal that corresponds to the given replica number.
+ //
+ // This returns an error if there is not a one-to-one correspondence of
+ // replicas to device ordinals, but is useful as a short term mechanism for
+ // the "easy" case where a single replica is a single device.
+ StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);
+
private:
explicit LocalService(const ServiceOptions& options,
std::unique_ptr<Backend> backend);
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index 5ca6ccb..c3743b1 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -199,12 +199,10 @@
StatusOr<const LogicalBuffer*> GetBufferDefinedAt(
const HloInstruction* instruction, const ShapeIndex& index) const;
- // Return a vector containing all BufferAliases of the given logical buffer
- // This trivially includes the BufferAlias with same instruction and index as
- // the logical buffer itself, so the returned vector is never empty. The
- // buffer alias set is the inverse of the points-to set. That is,
- // LogicalBuffer B is in the points-to set of instruction I at index N iff
- // instruction I, index N is a BufferAlias of B.
+ // Return a (possibly empty) vector containing all BufferAliases of the given
+ // logical buffer. The buffer alias set is the inverse of the points-to set.
+ // That is, LogicalBuffer B is in the points-to set of instruction I at index
+ // N iff instruction I, index N is a BufferAlias of B.
using BufferAliasVector = tensorflow::gtl::InlinedVector<BufferAlias, 1>;
const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const;
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index fb0e6f7..1073cc7 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -306,6 +306,11 @@
return false;
}
+ if (while_body_root->opcode() != HloOpcode::kTuple) {
+ VLOG(2) << "While body's root is not a tuple(...) instruction.";
+ return false;
+ }
+
auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
// Bail if param0 of while_cond or while_body has users which aren't of type
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index d99b31d..c5183f8 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -418,5 +418,32 @@
op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
}
+TEST_F(WhileLoopSimplifierTest, BodyHasNonTupleRoot) {
+ auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+ Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
+
+ HloComputation* while_body = [&]() {
+ HloComputation::Builder builder(TestName() + ".passthrough");
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, while_shape, "param"));
+ HloComputation* result = module().AddEmbeddedComputation(builder.Build());
+
+ result->AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+ return result;
+ }();
+
+ HloComputation::Builder builder(TestName());
+ auto* init_value = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, while_shape, "init_value"));
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
+ while_body, init_value));
+ module().AddEntryComputation(builder.Build());
+ TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+ WhileLoopSimplifier{}.Run(&module()));
+ EXPECT_FALSE(simplified_loop);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index c6e8b24..935b94c 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -1971,6 +1971,18 @@
error_spec_);
}
+XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f});
+ auto b = builder.ConstantR1<float>({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f});
+ auto atan = builder.Atan2(a, b);
+
+ ComputeAndCompareR1<float>(
+ &builder,
+ {0.0f, 1.57079633f, 3.14159265f, -1.57079633f, 0.78539816f, -0.78539816f},
+ {}, error_spec_);
+}
+
XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f});
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index e8a05cf..bb215be 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -29,7 +29,7 @@
CHECK_EQ(literal->shape().element_type(),
primitive_util::NativeToPrimitiveType<FloatT>());
std::minstd_rand0 engine;
- std::uniform_real_distribution<FloatT> generator(0.0f, 1.0f);
+ std::uniform_real_distribution<FloatT> generator(-0.9f, 1.0f);
TF_CHECK_OK(literal->Populate<FloatT>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return generator(engine);
@@ -42,7 +42,7 @@
void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) {
CHECK_EQ(literal->shape().element_type(), BF16);
std::minstd_rand0 engine;
- std::uniform_real_distribution<float> generator(0.0f, 1.0f);
+ std::uniform_real_distribution<float> generator(-0.9f, 1.0f);
TF_CHECK_OK(literal->Populate<bfloat16>(
[&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
return static_cast<bfloat16>(generator(engine));
@@ -126,6 +126,11 @@
fused_uses.end());
} else if (NeedsZeroInitValue(use)) {
constrained_uses.push_back(instruction);
+ } else if (opcode == HloOpcode::kConvert ||
+ opcode == HloOpcode::kReducePrecision) {
+ auto converted_uses = FindConstrainedUses(dataflow, *instruction);
+ constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
+ converted_uses.end());
}
}
}
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index 742835c..c9a3537 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -83,8 +83,11 @@
return Status::OK();
}
AttrValueMap attr_values = func_.attr();
- return lib_->Instantiate(func_.name(), AttrSlice(&attr_values),
- {target_device_}, &handle_);
+ AttrValue v;
+ v.set_s(target_device_);
+ AddAttr("_target", v, &attr_values);
+
+ return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), &handle_);
}
// Returns true if we've got to the end of the sequence and exhausted the
diff --git a/tensorflow/contrib/distributions/python/ops/test_util.py b/tensorflow/contrib/distributions/python/ops/test_util.py
index bfc7274..15b0820 100644
--- a/tensorflow/contrib/distributions/python/ops/test_util.py
+++ b/tensorflow/contrib/distributions/python/ops/test_util.py
@@ -327,7 +327,7 @@
num_samples=int(1e5),
seed=24,
rtol=1e-2,
- atol=0.,
+ atol=0.1,
cov_rtol=None,
cov_atol=None):
"""Tests that sample/mean/covariance are consistent with each other.
diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index 6aef010..15a2188 100644
--- a/tensorflow/contrib/eager/python/examples/BUILD
+++ b/tensorflow/contrib/eager/python/examples/BUILD
@@ -6,6 +6,7 @@
py_library(
name = "examples_pip",
deps = [
+ "//tensorflow/contrib/eager/python/examples/gan:mnist",
"//tensorflow/contrib/eager/python/examples/linear_regression",
"//tensorflow/contrib/eager/python/examples/mnist",
"//tensorflow/contrib/eager/python/examples/resnet50",
diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD
new file mode 100644
index 0000000..c61ec2d
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/BUILD
@@ -0,0 +1,36 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+py_binary(
+ name = "mnist",
+ srcs = ["mnist.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/eager/python:tfe",
+ "//tensorflow/examples/tutorials/mnist:input_data",
+ ],
+)
+
+cuda_py_test(
+ name = "mnist_test",
+ srcs = ["mnist_test.py"],
+ additional_deps = [
+ ":mnist",
+ "//tensorflow/contrib/eager/python:tfe",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+cuda_py_test(
+ name = "mnist_graph_test",
+ srcs = ["mnist_graph_test.py"],
+ additional_deps = [
+ ":mnist",
+ "//third_party/py/numpy",
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/gan/README.md b/tensorflow/contrib/eager/python/examples/gan/README.md
new file mode 100644
index 0000000..e8c9db1
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/README.md
@@ -0,0 +1,38 @@
+# GAN with TensorFlow eager execution
+
+A simple Generative Adversarial Network (GAN) example using eager execution.
+The discriminator and generator networks each contain a few convolution and
+fully connected layers.
+
+Other eager execution examples can be found under the parent directory.
+
+## Content
+
+- `mnist.py`: Model definitions and training routines.
+- `mnist_test.py`: Benchmarks for training and using the models using eager
+execution.
+- `mnist_graph_test.py`: Benchmarks for training and using the models using
+graph execution. The same model definitions and loss functions are used in
+all benchmarks.
+
+
+## To run
+
+- Make sure you have installed TensorFlow 1.5+ or the latest `tf-nightly`
+or `tf-nightly-gpu` pip package in order to access the eager execution feature.
+
+- Train model. E.g.,
+
+ ```bash
+ python mnist.py
+ ```
+
+ Use `--output_dir=<DIR>` to direct the script to save TensorBoard summaries
+ during training. Disabled by default.
+
+ Use `--checkpoint_dir=<DIR>` to direct the script to save checkpoints to
+ `<DIR>` during training. DIR defaults to /tmp/tensorflow/mnist/checkpoints/.
+ The script will load the latest saved checkpoint from this directory if
+ one exists.
+
+ Use `-h` for other options.
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py
new file mode 100644
index 0000000..b9ac79f
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py
@@ -0,0 +1,368 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A Generative Adversarial Network (GAN) for MNIST using eager execution.
+
+Sample usage:
+ python mnist.py --help
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+import time
+
+import tensorflow as tf
+
+import tensorflow.contrib.eager as tfe
+from tensorflow.examples.tutorials.mnist import input_data
+
+FLAGS = None
+
+
+class Discriminator(tfe.Network):
+ """GAN Discriminator.
+
+ A network to differentiate between generated and real handwritten digits.
+ """
+
+ def __init__(self, data_format):
+ """Creates a model for discriminating between real and generated digits.
+
+ Args:
+ data_format: Either 'channels_first' or 'channels_last'.
+ 'channels_first' is typically faster on GPUs while 'channels_last' is
+ typically faster on CPUs. See
+ https://www.tensorflow.org/performance/performance_guide#data_formats
+ """
+ super(Discriminator, self).__init__(name='')
+ if data_format == 'channels_first':
+ self._input_shape = [-1, 1, 28, 28]
+ else:
+ assert data_format == 'channels_last'
+ self._input_shape = [-1, 28, 28, 1]
+ self.conv1 = self.track_layer(tf.layers.Conv2D(64, 5, padding='SAME',
+ data_format=data_format,
+ activation=tf.tanh))
+ self.pool1 = self.track_layer(
+ tf.layers.AveragePooling2D(2, 2, data_format=data_format))
+ self.conv2 = self.track_layer(tf.layers.Conv2D(128, 5,
+ data_format=data_format,
+ activation=tf.tanh))
+ self.pool2 = self.track_layer(
+ tf.layers.AveragePooling2D(2, 2, data_format=data_format))
+ self.flatten = self.track_layer(tf.layers.Flatten())
+ self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.tanh))
+ self.fc2 = self.track_layer(tf.layers.Dense(1, activation=None))
+
+ def call(self, inputs):
+ """Return one logit per image estimating input authenticity.
+
+ Users should invoke __call__ to run the network, which delegates to this
+ method (and not call this method directly).
+
+ Args:
+ inputs: A batch of images as a Tensor with shape [batch_size, 28, 28, 1]
+ or [batch_size, 1, 28, 28]
+
+ Returns:
+ A Tensor with shape [batch_size, 1] containing logits estimating
+ the probability that corresponding digit is real.
+ """
+ x = tf.reshape(inputs, self._input_shape)
+ x = self.conv1(x)
+ x = self.pool1(x)
+ x = self.conv2(x)
+ x = self.pool2(x)
+ x = self.flatten(x)
+ x = self.fc1(x)
+ x = self.fc2(x)
+ return x
+
+
+class Generator(tfe.Network):
+ """Generator of handwritten digits similar to the ones in the MNIST dataset.
+ """
+
+ def __init__(self, data_format):
+ """Creates a model for generating handwritten digits from noise vectors.
+
+ Args:
+ data_format: Either 'channels_first' or 'channels_last'.
+ 'channels_first' is typically faster on GPUs while 'channels_last' is
+ typically faster on CPUs. See
+ https://www.tensorflow.org/performance/performance_guide#data_formats
+ """
+ super(Generator, self).__init__(name='')
+ self.data_format = data_format
+ # We are using 128 6x6 channels as input to the first deconvolution layer
+ if data_format == 'channels_first':
+ self._pre_conv_shape = [-1, 128, 6, 6]
+ else:
+ assert data_format == 'channels_last'
+ self._pre_conv_shape = [-1, 6, 6, 128]
+ self.fc1 = self.track_layer(tf.layers.Dense(6 * 6 * 128,
+ activation=tf.tanh))
+
+ # In call(), we reshape the output of fc1 to _pre_conv_shape
+
+ # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64)
+ self.conv1 = self.track_layer(tf.layers.Conv2DTranspose(
+ 64, 4, strides=2, activation=None, data_format=data_format))
+
+ # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1)
+ self.conv2 = self.track_layer(tf.layers.Conv2DTranspose(
+ 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format))
+
+ def call(self, inputs):
+ """Return a batch of generated images.
+
+ Users should invoke __call__ to run the network, which delegates to this
+ method (and not call this method directly).
+
+ Args:
+ inputs: A batch of noise vectors as a Tensor with shape
+ [batch_size, length of noise vectors].
+
+ Returns:
+ A Tensor containing generated images. If data_format is 'channels_last',
+ the shape of returned images is [batch_size, 28, 28, 1], else
+ [batch_size, 1, 28, 28]
+ """
+
+ x = self.fc1(inputs)
+ x = tf.reshape(x, shape=self._pre_conv_shape)
+ x = self.conv1(x)
+ x = self.conv2(x)
+ return x
+
+
+def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs):
+ """Original discriminator loss for GANs, with label smoothing.
+
+ See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
+ details.
+
+ Args:
+ discriminator_real_outputs: Discriminator output on real data.
+ discriminator_gen_outputs: Discriminator output on generated data. Expected
+ to be in the range of (-inf, inf).
+
+ Returns:
+ A scalar loss Tensor.
+ """
+
+ loss_on_real = tf.losses.sigmoid_cross_entropy(
+ tf.ones_like(discriminator_real_outputs), discriminator_real_outputs,
+ label_smoothing=0.25)
+ loss_on_generated = tf.losses.sigmoid_cross_entropy(
+ tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs)
+ loss = loss_on_real + loss_on_generated
+ tf.contrib.summary.scalar('discriminator_loss', loss)
+ return loss
+
+
+def generator_loss(discriminator_gen_outputs):
+ """Original generator loss for GANs.
+
+ L = -log(sigmoid(D(G(z))))
+
+ See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661)
+ for more details.
+
+ Args:
+ discriminator_gen_outputs: Discriminator output on generated data. Expected
+ to be in the range of (-inf, inf).
+
+ Returns:
+ A scalar loss Tensor.
+ """
+ loss = tf.losses.sigmoid_cross_entropy(
+ tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs)
+ tf.contrib.summary.scalar('generator_loss', loss)
+ return loss
+
+
+def train_one_epoch(generator, discriminator,
+ generator_optimizer, discriminator_optimizer,
+ dataset, log_interval, noise_dim):
+ """Trains `generator` and `discriminator` models on `dataset`.
+
+ Args:
+ generator: Generator model.
+ discriminator: Discriminator model.
+ generator_optimizer: Optimizer to use for generator.
+ discriminator_optimizer: Optimizer to use for discriminator.
+ dataset: Dataset of images to train on.
+ log_interval: How many global steps to wait between logging and collecting
+ summaries.
+ noise_dim: Dimension of noise vector to use.
+ """
+
+ total_generator_loss = 0.0
+ total_discriminator_loss = 0.0
+ for (batch_index, images) in enumerate(tfe.Iterator(dataset)):
+ with tf.device('/cpu:0'):
+ tf.assign_add(tf.train.get_global_step(), 1)
+
+ with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval):
+ current_batch_size = images.shape[0]
+ noise = tf.random_uniform(shape=[current_batch_size, noise_dim],
+ minval=-1., maxval=1., seed=batch_index)
+
+ with tfe.GradientTape(persistent=True) as g:
+ generated_images = generator(noise)
+ tf.contrib.summary.image('generated_images',
+ tf.reshape(generated_images, [-1, 28, 28, 1]),
+ max_images=10)
+
+ discriminator_gen_outputs = discriminator(generated_images)
+ discriminator_real_outputs = discriminator(images)
+ discriminator_loss_val = discriminator_loss(discriminator_real_outputs,
+ discriminator_gen_outputs)
+ total_discriminator_loss += discriminator_loss_val
+
+ generator_loss_val = generator_loss(discriminator_gen_outputs)
+ total_generator_loss += generator_loss_val
+
+ generator_grad = g.gradient(generator_loss_val, generator.variables)
+ discriminator_grad = g.gradient(discriminator_loss_val,
+ discriminator.variables)
+
+ with tf.variable_scope('generator'):
+ generator_optimizer.apply_gradients(zip(generator_grad,
+ generator.variables))
+ with tf.variable_scope('discriminator'):
+ discriminator_optimizer.apply_gradients(zip(discriminator_grad,
+ discriminator.variables))
+
+ if log_interval and batch_index > 0 and batch_index % log_interval == 0:
+ print('Batch #%d\tAverage Generator Loss: %.6f\t'
+ 'Average Discriminator Loss: %.6f' % (
+ batch_index, total_generator_loss/(batch_index + 1),  # 0-based index
+ total_discriminator_loss/(batch_index + 1)))
+
+
+def main(_):
+ (device, data_format) = ('/gpu:0', 'channels_first')
+ if FLAGS.no_gpu or tfe.num_gpus() <= 0:
+ (device, data_format) = ('/cpu:0', 'channels_last')
+ print('Using device %s, and data format %s.' % (device, data_format))
+
+ # Load the datasets
+ data = input_data.read_data_sets(FLAGS.data_dir)
+ dataset = (tf.data.Dataset
+ .from_tensor_slices(data.train.images)
+ .shuffle(60000)
+ .batch(FLAGS.batch_size))
+
+ # Create the models and optimizers
+ generator = Generator(data_format)
+ discriminator = Discriminator(data_format)
+ with tf.variable_scope('generator'):
+ generator_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
+ with tf.variable_scope('discriminator'):
+ discriminator_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
+
+ # Prepare summary writer and checkpoint info
+ summary_writer = tf.contrib.summary.create_summary_file_writer(
+ FLAGS.output_dir, flush_millis=1000)
+ checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
+ latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
+ if latest_cpkt:
+ print('Using latest checkpoint at ' + latest_cpkt)
+
+ with tf.device(device):
+ for epoch in range(1, 101):
+ with tfe.restore_variables_on_create(latest_cpkt):
+ global_step = tf.train.get_or_create_global_step()
+ start = time.time()
+ with summary_writer.as_default():
+ train_one_epoch(generator, discriminator, generator_optimizer,
+ discriminator_optimizer,
+ dataset, FLAGS.log_interval, FLAGS.noise)
+ end = time.time()
+ print('\nTrain time for epoch #%d (global step %d): %f' % (
+ epoch, global_step.numpy(), end - start))
+
+ all_variables = (
+ generator.variables
+ + discriminator.variables
+ + generator_optimizer.variables()
+ + discriminator_optimizer.variables()
+ + [global_step])
+ tfe.Saver(all_variables).save(
+ checkpoint_prefix, global_step=global_step)
+
+
+if __name__ == '__main__':
+ tfe.enable_eager_execution()
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--data-dir',
+ type=str,
+ default='/tmp/tensorflow/mnist/input_data',
+ help=('Directory for storing input data (default '
+ '/tmp/tensorflow/mnist/input_data)'))
+ parser.add_argument(
+ '--batch-size',
+ type=int,
+ default=128,
+ metavar='N',
+ help='input batch size for training (default: 128)')
+ parser.add_argument(
+ '--log-interval',
+ type=int,
+ default=100,
+ metavar='N',
+ help=('number of batches between logging and writing summaries '
+ '(default: 100)'))
+ parser.add_argument(
+ '--output_dir',
+ type=str,
+ default=None,
+ metavar='DIR',
+ help='Directory to write TensorBoard summaries (defaults to none)')
+ parser.add_argument(
+ '--checkpoint_dir',
+ type=str,
+ default='/tmp/tensorflow/mnist/checkpoints/',
+ metavar='DIR',
+ help=('Directory to save checkpoints in (once per epoch) (default '
+ '/tmp/tensorflow/mnist/checkpoints/)'))
+ parser.add_argument(
+ '--lr',
+ type=float,
+ default=0.001,
+ metavar='LR',
+ help='learning rate (default: 0.001)')
+ parser.add_argument(
+ '--noise',
+ type=int,
+ default=100,
+ metavar='N',
+ help='Length of noise vector for generator input (default: 100)')
+ parser.add_argument(
+ '--no-gpu',
+ action='store_true',
+ default=False,
+ help='disables GPU usage even if a GPU is available')
+
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py
new file mode 100644
index 0000000..12b39b0
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py
@@ -0,0 +1,151 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tempfile
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.eager.python.examples.gan import mnist
+
+NOISE_DIM = 100
+# Big enough so that summaries are never recorded.
+# Lower this value if you would like to benchmark with some summaries.
+SUMMARY_INTERVAL = 10000
+SUMMARY_FLUSH_MS = 100 # Flush summaries every 100ms
+
+
+def data_format():
+ return 'channels_first' if tf.test.is_gpu_available() else 'channels_last'
+
+
+class MnistGraphGanBenchmark(tf.test.Benchmark):
+
+ def _create_graph(self, batch_size):
+ # Generate some random data.
+ images_data = np.random.randn(batch_size, 784).astype(np.float32)
+ dataset = tf.data.Dataset.from_tensors(images_data)
+ images = dataset.repeat().make_one_shot_iterator().get_next()
+
+ # Create the models and optimizers
+ generator = mnist.Generator(data_format())
+ discriminator = mnist.Discriminator(data_format())
+ with tf.variable_scope('generator'):
+ generator_optimizer = tf.train.AdamOptimizer(0.001)
+ with tf.variable_scope('discriminator'):
+ discriminator_optimizer = tf.train.AdamOptimizer(0.001)
+
+ # Run models and compute loss
+ noise_placeholder = tf.placeholder(tf.float32,
+ shape=[batch_size, NOISE_DIM])
+ generated_images = generator(noise_placeholder)
+ tf.contrib.summary.image('generated_images',
+ tf.reshape(generated_images, [-1, 28, 28, 1]),
+ max_images=10)
+ discriminator_gen_outputs = discriminator(generated_images)
+ discriminator_real_outputs = discriminator(images)
+ generator_loss = mnist.generator_loss(discriminator_gen_outputs)
+ discriminator_loss = mnist.discriminator_loss(discriminator_real_outputs,
+ discriminator_gen_outputs)
+ # Get train ops
+ with tf.variable_scope('generator'):
+ generator_train = generator_optimizer.minimize(
+ generator_loss, var_list=generator.variables)
+ with tf.variable_scope('discriminator'):
+ discriminator_train = discriminator_optimizer.minimize(
+ discriminator_loss, var_list=discriminator.variables)
+
+ return (generator_train, discriminator_train, noise_placeholder)
+
+ def _report(self, test_name, start, num_iters, batch_size):
+ avg_time = (time.time() - start) / num_iters
+ dev = 'gpu' if tf.test.is_gpu_available() else 'cpu'
+ name = 'graph_%s_%s_batch_%d_%s' % (test_name, dev, batch_size,
+ data_format())
+ extras = {'examples_per_sec': batch_size / avg_time}
+ self.report_benchmark(
+ iters=num_iters, wall_time=avg_time, name=name, extras=extras)
+
+ def benchmark_train(self):
+ for batch_size in [64, 128, 256]:
+ with tf.Graph().as_default():
+ global_step = tf.train.get_or_create_global_step()
+ increment_global_step = tf.assign_add(global_step, 1)
+ with tf.contrib.summary.create_file_writer(
+ tempfile.mkdtemp(), flush_millis=SUMMARY_FLUSH_MS).as_default(), (
+ tf.contrib.summary.record_summaries_every_n_global_steps(
+ SUMMARY_INTERVAL)):
+ (generator_train, discriminator_train, noise_placeholder
+ ) = self._create_graph(batch_size)
+
+ with tf.Session() as sess:
+ tf.contrib.summary.initialize(graph=tf.get_default_graph(),
+ session=sess)
+
+ sess.run(tf.global_variables_initializer())
+
+ num_burn, num_iters = (3, 100)
+ for _ in range(num_burn):
+ noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM])
+ # Increment global step before evaluating summary ops to avoid
+ # race condition.
+ sess.run(increment_global_step)
+ sess.run([generator_train, discriminator_train,
+ tf.contrib.summary.all_summary_ops()],
+ feed_dict={noise_placeholder: noise})
+
+ # Run and time num_iters training steps.
+ start = time.time()
+ for _ in range(num_iters):
+ noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM])
+ sess.run(increment_global_step)
+ sess.run([generator_train, discriminator_train,
+ tf.contrib.summary.all_summary_ops()],
+ feed_dict={noise_placeholder: noise})
+ self._report('train', start, num_iters, batch_size)
+
+ def benchmark_generate(self):
+ for batch_size in [64, 128, 256]:
+ with tf.Graph().as_default():
+ # Using random weights. This will generate garbage.
+ generator = mnist.Generator(data_format())
+ noise_placeholder = tf.placeholder(tf.float32,
+ shape=[batch_size, NOISE_DIM])
+ generated_images = generator(noise_placeholder)
+
+ init = tf.global_variables_initializer()
+ with tf.Session() as sess:
+ sess.run(init)
+ noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM])
+ num_burn, num_iters = (30, 1000)
+ for _ in range(num_burn):
+ sess.run(generated_images, feed_dict={noise_placeholder: noise})
+
+ start = time.time()
+ for _ in range(num_iters):
+ # Comparison with the eager execution benchmark in mnist_test.py
+ # isn't entirely fair as the time here includes the cost of copying
+ # the feeds from CPU memory to GPU.
+ sess.run(generated_images, feed_dict={noise_placeholder: noise})
+ self._report('generate', start, num_iters, batch_size)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py
new file mode 100644
index 0000000..4a3ca8d
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py
@@ -0,0 +1,113 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tempfile
+import time
+
+import tensorflow as tf
+
+import tensorflow.contrib.eager as tfe
+from tensorflow.contrib.eager.python.examples.gan import mnist
+
+NOISE_DIM = 100
+# Big enough so that summaries are never recorded.
+# Lower this value if you would like to benchmark with some summaries.
+SUMMARY_INTERVAL = 10000
+SUMMARY_FLUSH_MS = 100 # Flush summaries every 100ms
+
+
+def data_format():
+ return 'channels_first' if tf.test.is_gpu_available() else 'channels_last'
+
+
+def device():
+ return '/gpu:0' if tfe.num_gpus() else '/cpu:0'
+
+
+class MnistEagerGanBenchmark(tf.test.Benchmark):
+
+ def _report(self, test_name, start, num_iters, batch_size):
+ avg_time = (time.time() - start) / num_iters
+ dev = 'gpu' if tfe.num_gpus() else 'cpu'
+ name = 'eager_%s_%s_batch_%d_%s' % (test_name, dev, batch_size,
+ data_format())
+ extras = {'examples_per_sec': batch_size / avg_time}
+ self.report_benchmark(
+ iters=num_iters, wall_time=avg_time, name=name, extras=extras)
+
+ def benchmark_train(self):
+ for batch_size in [64, 128, 256]:
+ # Generate some random data.
+ burn_batches, measure_batches = (3, 100)
+ burn_images = [tf.random_normal([batch_size, 784])
+ for _ in range(burn_batches)]
+ burn_dataset = tf.data.Dataset.from_tensor_slices(burn_images)
+ measure_images = [tf.random_normal([batch_size, 784])
+ for _ in range(measure_batches)]
+ measure_dataset = tf.data.Dataset.from_tensor_slices(measure_images)
+
+ tf.train.get_or_create_global_step()
+ with tf.device(device()):
+ # Create the models and optimizers
+ generator = mnist.Generator(data_format())
+ discriminator = mnist.Discriminator(data_format())
+ with tf.variable_scope('generator'):
+ generator_optimizer = tf.train.AdamOptimizer(0.001)
+ with tf.variable_scope('discriminator'):
+ discriminator_optimizer = tf.train.AdamOptimizer(0.001)
+
+ with tf.contrib.summary.create_file_writer(
+ tempfile.mkdtemp(), flush_millis=SUMMARY_FLUSH_MS).as_default():
+
+ # warm up
+ mnist.train_one_epoch(generator, discriminator, generator_optimizer,
+ discriminator_optimizer,
+ burn_dataset, log_interval=SUMMARY_INTERVAL,
+ noise_dim=NOISE_DIM)
+ # measure
+ start = time.time()
+ mnist.train_one_epoch(generator, discriminator, generator_optimizer,
+ discriminator_optimizer,
+ measure_dataset, log_interval=SUMMARY_INTERVAL,
+ noise_dim=NOISE_DIM)
+ self._report('train', start, measure_batches, batch_size)
+
+ def benchmark_generate(self):
+ for batch_size in [64, 128, 256]:
+ with tf.device(device()):
+ # Using random weights. This will generate garbage.
+ generator = mnist.Generator(data_format())
+
+ num_burn, num_iters = (30, 1000)
+ for _ in range(num_burn):
+ noise = tf.random_uniform(shape=[batch_size, NOISE_DIM],
+ minval=-1., maxval=1.)
+ generator(noise)
+
+ start = time.time()
+ for _ in range(num_iters):
+ noise = tf.random_uniform(shape=[batch_size, NOISE_DIM],
+ minval=-1., maxval=1.)
+ generator(noise)
+ self._report('generate', start, num_iters, batch_size)
+
+
+if __name__ == '__main__':
+ tfe.enable_eager_execution()
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py
index 205709f..136085e 100644
--- a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py
+++ b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py
@@ -39,22 +39,40 @@
return tf.data.Dataset.from_tensors((images, labels))
+def train_one_epoch(defun=False):
+ model = mnist.MNISTModel(data_format())
+ if defun:
+ model.call = tfe.defun(model.call)
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
+ dataset = random_dataset()
+ with tf.device(device()):
+ tf.train.get_or_create_global_step()
+ mnist.train_one_epoch(model, optimizer, dataset)
+
+
+def evaluate(defun=False):
+ model = mnist.MNISTModel(data_format())
+ dataset = random_dataset()
+ if defun:
+ model.call = tfe.defun(model.call)
+ with tf.device(device()):
+ tf.train.get_or_create_global_step()
+ mnist.test(model, dataset)
+
+
class MNISTTest(tf.test.TestCase):
def testTrainOneEpoch(self):
- model = mnist.MNISTModel(data_format())
- optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
- dataset = random_dataset()
- with tf.device(device()):
- tf.train.get_or_create_global_step()
- mnist.train_one_epoch(model, optimizer, dataset)
+ train_one_epoch(defun=False)
def testTest(self):
- model = mnist.MNISTModel(data_format())
- dataset = random_dataset()
- with tf.device(device()):
- tf.train.get_or_create_global_step()
- mnist.test(model, dataset)
+ evaluate(defun=False)
+
+ def testTrainOneEpochWithDefunCall(self):
+ train_one_epoch(defun=True)
+
+ def testTestWithDefunCall(self):
+ evaluate(defun=True)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index d8d8644..932f95c 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -64,14 +64,22 @@
class ResNet50Test(tf.test.TestCase):
- def test_apply(self):
+ def _apply(self, defun=False):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format)
+ if defun:
+ model.call = tfe.defun(model.call)
with tf.device(device):
images, _ = random_batch(2)
output = model(images)
self.assertEqual((2, 1000), output.shape)
+ def test_apply(self):
+ self._apply(defun=False)
+
+ def test_apply_with_defun(self):
+ self._apply(defun=True)
+
def test_apply_no_top(self):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format, include_top=False)
@@ -175,9 +183,11 @@
# a sync. This is a roundabout way, yes.
tf.constant(1.).cpu()
- def benchmark_eager_apply(self):
+ def _benchmark_eager_apply(self, label, defun=False):
device, data_format = device_and_data_format()
model = resnet50.ResNet50(data_format)
+ if defun:
+ model.call = tfe.defun(model.call)
batch_size = 64
num_burn = 5
num_iters = 30
@@ -189,16 +199,23 @@
start = time.time()
for _ in xrange(num_iters):
model(images).cpu()
- self._report('eager_apply', start, num_iters, device, batch_size,
- data_format)
+ self._report(label, start, num_iters, device, batch_size, data_format)
- def _benchmark_eager_train(self, label, make_iterator):
+ def benchmark_eager_apply(self):
+ self._benchmark_eager_apply('eager_apply', defun=False)
+
+ def benchmark_eager_apply_with_defun(self):
+ self._benchmark_eager_apply('eager_apply_with_defun', defun=True)
+
+ def _benchmark_eager_train(self, label, make_iterator, defun=False):
device, data_format = device_and_data_format()
for batch_size in self._train_batch_sizes():
(images, labels) = random_batch(batch_size)
num_burn = 3
num_iters = 10
model = resnet50.ResNet50(data_format)
+ if defun:
+ model.call = tfe.defun(model.call)
optimizer = tf.train.GradientDescentOptimizer(0.1)
with tf.device(device):
@@ -217,7 +234,11 @@
self._report(label, start, num_iters, device, batch_size, data_format)
def benchmark_eager_train(self):
- self._benchmark_eager_train('eager_train', MockIterator)
+ self._benchmark_eager_train('eager_train', MockIterator, defun=False)
+
+ def benchmark_eager_train_with_defun(self):
+ self._benchmark_eager_train(
+ 'eager_train_with_defun', MockIterator, defun=True)
def benchmark_eager_train_datasets(self):
@@ -226,7 +247,18 @@
ds = tf.data.Dataset.from_tensors(tensors).repeat()
return tfe.Iterator(ds)
- self._benchmark_eager_train('eager_train_dataset', make_iterator)
+ self._benchmark_eager_train(
+ 'eager_train_dataset', make_iterator, defun=False)
+
+ def benchmark_eager_train_datasets_with_defun(self):
+ """Like benchmark_eager_train_datasets, but with a defun-wrapped call."""
+ def make_iterator(tensors):
+ with tf.device('/device:CPU:0'):
+ ds = tf.data.Dataset.from_tensors(tensors).repeat()
+ return tfe.Iterator(ds)
+
+ self._benchmark_eager_train(
+ 'eager_train_dataset_with_defun', make_iterator, defun=True)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py
index 3eb4f5f..8e6b947 100644
--- a/tensorflow/contrib/eager/python/network_test.py
+++ b/tensorflow/contrib/eager/python/network_test.py
@@ -105,15 +105,13 @@
result = net(constant_op.constant([[2.0]]))
self.assertEqual(34.0, self.evaluate(result))
- # TODO(akshayka): This test should be changed once an API for compiling
- # `call` into a defun is implemented.
def testReplacingNetworkCallWithDefun(self):
net = MyNetwork(name="abcd")
+ net.call = function.defun(net.call)
x = constant_op.constant([[2.0]])
net(x) # Force variables to be created.
self.evaluate(net.trainable_variables[0].assign([[17.0]]))
- net.call = function.defun(net.call)
result = net(x) # Build and execute the TensorFlow function
self.assertEqual(34.0, self.evaluate(result))
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 1150328..a9bdbe0 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -3237,7 +3237,11 @@
images = random_ops.random_uniform((5, height, width, 3), seed=1)
regularizer = regularizers.l2_regularizer(0.01)
layers_lib.separable_conv2d(
- images, 32, [3, 3], 2, weights_regularizer=regularizer)
+ images,
+ 32, [3, 3],
+ 2,
+ weights_regularizer=regularizer,
+ weights_initializer=init_ops.ones_initializer())
self.assertEqual(
len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 2)
weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0]
@@ -3245,12 +3249,31 @@
weight_decay.op.name,
'SeparableConv2d/depthwise_kernel/Regularizer/l2_regularizer')
sess.run(variables_lib.global_variables_initializer())
- self.assertLessEqual(sess.run(weight_decay), 0.05)
+ depth_weight_one = sess.run(weight_decay)
weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[1]
self.assertEqual(
weight_decay.op.name,
'SeparableConv2d/pointwise_kernel/Regularizer/l2_regularizer')
- self.assertLessEqual(sess.run(weight_decay), 0.05)
+ pointwise_weight_one = sess.run(weight_decay)
+
+ regularizer = regularizers.l2_regularizer(1.0)
+ layers_lib.separable_conv2d(
+ images,
+ 32, [3, 3],
+ 2,
+ weights_regularizer=regularizer,
+ weights_initializer=init_ops.ones_initializer())
+ self.assertEqual(
+ len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 4)
+ weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[2]
+ sess.run(variables_lib.global_variables_initializer())
+ depth_weight_two = sess.run(weight_decay)
+ weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[3]
+ pointwise_weight_two = sess.run(weight_decay)
+
+ self.assertAllClose(
+ [100.0 * depth_weight_one, 100.0 * pointwise_weight_one],
+ [depth_weight_two, pointwise_weight_two])
def testReuseConvWithWeightDecay(self):
height, width = 3, 3
diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
index 0dafb1f..a885a57 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
+++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
@@ -96,19 +96,19 @@
}
NSString* RunInferenceOnImage() {
- std::string graph;
+ NSString* graph = @"mobilenet_v1_1.0_224";
const int num_threads = 1;
std::string input_layer_type = "float";
std::vector<int> sizes = {1, 224, 224, 3};
- NSString* graph_path = FilePathForResourceName(@"mobilenet_v1_1.0_224", @"tflite");
+ const NSString* graph_path = FilePathForResourceName(graph, @"tflite");
std::unique_ptr<tflite::FlatBufferModel> model(
tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]));
if (!model) {
- LOG(FATAL) << "Failed to mmap model " << graph;
+ LOG(FATAL) << "Failed to mmap model " << [graph UTF8String];
}
- LOG(INFO) << "Loaded model " << graph;
+ LOG(INFO) << "Loaded model " << [graph UTF8String];
model->error_reporter();
LOG(INFO) << "resolved reporter";
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 81567ce..81df517 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -247,7 +247,8 @@
INSTANTIATE_TESTS(space_to_batch_nd)
INSTANTIATE_TESTS(batch_to_space_nd)
INSTANTIATE_TESTS(concat)
-INSTANTIATE_TESTS(constant)
+// TODO(b/71642435) re-enable this test
+// INSTANTIATE_TESTS(constant)
INSTANTIATE_TESTS(control_dep)
INSTANTIATE_TESTS(conv)
INSTANTIATE_TESTS(depthwiseconv)
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 51d76e4..90fa442 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -802,8 +802,10 @@
*reshape_op->add_input() = src_op.inputs[1];
(*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
const auto& shape_array = model.GetArray(src_op.inputs[1]);
- CHECK(shape_array.data_type == ArrayDataType::kInt32);
- CHECK(shape_array.buffer != nullptr);
+ QCHECK(shape_array.data_type == ArrayDataType::kInt32)
+ << "Only int32 shape is supported.";
+ QCHECK(shape_array.buffer != nullptr)
+ << "Shape inferred at runtime is not supported.";
const auto& shape_data = shape_array.GetBuffer<ArrayDataType::kInt32>().data;
CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph);
}
diff --git a/tensorflow/contrib/model_pruning/python/layers/core_layers.py b/tensorflow/contrib/model_pruning/python/layers/core_layers.py
index 95dfd8f..764ab62 100644
--- a/tensorflow/contrib/model_pruning/python/layers/core_layers.py
+++ b/tensorflow/contrib/model_pruning/python/layers/core_layers.py
@@ -210,7 +210,7 @@
return self.activation(outputs)
return outputs
- def _compute_output_shape(self, input_shape):
+ def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
if self.data_format == 'channels_last':
space = input_shape[1:-1]
@@ -467,7 +467,7 @@
return self.activation(outputs) # pylint: disable=not-callable
return outputs
- def _compute_output_shape(self, input_shape):
+ def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
input_shape = input_shape.with_rank_at_least(2)
if input_shape[-1].value is None:
diff --git a/tensorflow/contrib/py2tf/pyct/BUILD b/tensorflow/contrib/py2tf/pyct/BUILD
index 4322852..dca380c 100644
--- a/tensorflow/contrib/py2tf/pyct/BUILD
+++ b/tensorflow/contrib/py2tf/pyct/BUILD
@@ -22,6 +22,7 @@
"compiler.py",
"parser.py",
"pretty_printer.py",
+ "templates.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
@@ -30,6 +31,10 @@
py_test(
name = "anno_test",
srcs = ["anno_test.py"],
+ tags = [
+ "manual",
+ "notap",
+ ],
deps = [
":pyct",
"//tensorflow/python:client_testlib",
@@ -74,3 +79,16 @@
"//tensorflow/python:client_testlib",
],
)
+
+py_test(
+ name = "templates_test",
+ srcs = ["templates_test.py"],
+ tags = [
+ "manual",
+ "notap",
+ ],
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py
new file mode 100644
index 0000000..6acc03b
--- /dev/null
+++ b/tensorflow/contrib/py2tf/pyct/templates.py
@@ -0,0 +1,112 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AST conversion templates.
+
+Adapted from Tangent.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import parser
+
+
+class ReplaceTransformer(gast.NodeTransformer):
+ """Replace AST nodes."""
+
+ def __init__(self, replacements):
+ """Create a new ReplaceTransformer.
+
+ Args:
+ replacements: A mapping from placeholder names to (lists of) AST nodes
+ that these placeholders will be replaced by.
+ """
+ self.replacements = replacements
+
+ # TODO(mdan): Make a more detailed pass and clean up if needed.
+
+ def visit_Expr(self, node):
+ if (isinstance(node.value, gast.Name) and
+ node.value.id in self.replacements):
+ return self.visit(node.value)
+ self.generic_visit(node)
+ return node
+
+ def visit_FunctionDef(self, node):
+ node = self.generic_visit(node)
+ if node.name in self.replacements:
+ repl = self.replacements[node.name]
+ if not isinstance(repl, (gast.Name, ast.Name)):
+ raise ValueError(
+ 'A function name can only be replaced by a Name node. Found: %s' %
+ (repl,))
+ node.name = repl.id  # Rename only; body already visited above.
+ return node
+
+ def visit_Name(self, node):
+ # Note: The caller is responsible for making sure the replacement
+ # Name nodes have the proper ctx set up.
+ # TODO(mdan): Is it possible to always infer the proper context here?
+ if node.id in self.replacements:
+ # TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations.
+ new_nodes = self.replacements[node.id]
+ if isinstance(new_nodes, gast.AST):
+ new_nodes = [new_nodes]
+ if len(new_nodes) == 1:
+ new_nodes, = new_nodes
+ return new_nodes
+ else:
+ return node
+
+
+def replace(template, **replacements):
+ """Replace placeholders in a Python template.
+
+ Args:
+ template: A function to be used as a template. Any placeholder is expected
+ to also be a function argument.
+ **replacements: A mapping from placeholder names to (lists of) AST nodes
+ that these placeholders will be replaced by.
+
+ Returns:
+ body: An AST node or list of AST nodes with the replacements made. If the
+ template was a function, a list will be returned. If the template was a
+ node, the same node will be returned. If the template was a string, an
+ AST node will be returned (a `Module` node in the case of a multi-line
+ string, an `Expr` node otherwise).
+
+ Raises:
+ ValueError: If a function is used as a template and an incorrect set of
+ replacements was passed.
+ """
+ tree = parser.parse_object(template).body[0]
+ placeholders = set(arg.id for arg in tree.args.args)
+ tree.args.args = []
+ if tree.args.vararg:
+ placeholders.add(tree.args.vararg)
+ tree.args.vararg = None
+ if set(replacements.keys()) != placeholders:
+ raise ValueError(
+ 'too many or few replacements. replacements: %s; placeholders: %s' %
+ (replacements.keys(), placeholders))
+
+ # Perform the replacement, stripping the function into which the template was
+ # wrapped.
+ return ReplaceTransformer(replacements).visit(tree).body
diff --git a/tensorflow/contrib/py2tf/pyct/templates_test.py b/tensorflow/contrib/py2tf/pyct/templates_test.py
new file mode 100644
index 0000000..2ad8b93
--- /dev/null
+++ b/tensorflow/contrib/py2tf/pyct/templates_test.py
@@ -0,0 +1,77 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for templates module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.python.platform import test
+
+
+class TemplatesTest(test.TestCase):
+
+ def test_replace_variable(self):
+ def template(a): # pylint:disable=unused-argument
+ def test_fn(a): # pylint:disable=unused-variable
+ a += 1
+ a = 2 * a + 1
+ return b # pylint:disable=undefined-variable
+
+ node = templates.replace(
+ template, a=gast.Name('b', gast.Load(), None))[0]
+ result = compiler.ast_to_object(node)
+ self.assertEquals(7, result.test_fn(2))
+
+ def test_replace_function_name(self):
+ def template(fname): # pylint:disable=unused-argument
+ def fname(a): # pylint:disable=function-redefined
+ a += 1
+ a = 2 * a + 1
+ return a
+
+ node = templates.replace(
+ template, fname=gast.Name('test_fn', gast.Load(), None))[0]
+ result = compiler.ast_to_object(node)
+ self.assertEquals(7, result.test_fn(2))
+
+ def test_code_block(self):
+ def template(block): # pylint:disable=unused-argument
+ def test_fn(a): # pylint:disable=unused-variable
+ block # pylint:disable=pointless-statement
+ return a
+
+ node = templates.replace(
+ template,
+ block=[
+ gast.Assign(
+ [
+ gast.Name('a', gast.Store(), None)
+ ],
+ gast.BinOp(
+ gast.Name('a', gast.Load(), None),
+ gast.Add(),
+ gast.Num(1))),
+ ] * 2)[0]
+ result = compiler.ast_to_object(node)
+ self.assertEquals(3, result.test_fn(1))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 46823fa..7378920 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -845,12 +845,14 @@
batch_size = 3
input_size = 4
expected_state_c = np.array(
- [[0.00072015, 0.00036633], [0.00083481, 0.00047266],
- [0.00085111, 0.00053054]],
+ [[6.450831e-04, 4.697885e-04],
+ [9.862894e-05, 7.212213e-04],
+ [4.401947e-04, 9.143004e-04]],
dtype=np.float32)
expected_state_h = np.array(
- [[0.0005159, 0.00026243], [0.00062958, 0.00035646],
- [0.00064732, 0.00040351]],
+ [[4.621217e-04, 3.365449e-04],
+ [7.438179e-05, 5.439147e-04],
+ [3.347936e-04, 6.953785e-04]],
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
@@ -1328,7 +1330,7 @@
h_low = 0.761552567265
h_high = 0.995008519604
num_units = 5
- allowed_low = [2, 3]
+ allowed_low = [1, 2, 3]
with self.test_session() as sess:
with variable_scope.variable_scope(
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index e5d5917..7465f20 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -69,7 +69,7 @@
def assertAllCloseOrEqual(self, x, y, **kwargs):
if isinstance(x, np.ndarray) or isinstance(x, float):
return super(AttentionWrapperTest, self).assertAllClose(
- x, y, atol=1e-4, **kwargs)
+ x, y, atol=1e-3, **kwargs)
else:
self.assertAllEqual(x, y, **kwargs)
@@ -276,7 +276,7 @@
rnn_output=ResultSummary(
shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00597103),
sample_id=ResultSummary(
- shape=(5, 3), dtype=dtype('int32'), mean=1.4))
+ shape=(5, 3), dtype=dtype('int32'), mean=1.6))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=ResultSummary(
@@ -305,7 +305,7 @@
rnn_output=ResultSummary(
shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0052615386),
sample_id=ResultSummary(
- shape=(5, 3), dtype=dtype('int32'), mean=1.4666666666666666))
+ shape=(5, 3), dtype=dtype('int32'), mean=1.3333333333))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=ResultSummary(
@@ -336,7 +336,7 @@
rnn_output=ResultSummary(
shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0052615386),
sample_id=ResultSummary(
- shape=(5, 3), dtype=dtype('int32'), mean=1.4666666666666666))
+ shape=(5, 3), dtype=dtype('int32'), mean=1.3333333333333333))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=ResultSummary(
@@ -578,7 +578,7 @@
rnn_output=ResultSummary(
shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0025896581),
sample_id=ResultSummary(
- shape=(5, 3), dtype=dtype('int32'), mean=1.8666666666666667))
+ shape=(5, 3), dtype=dtype('int32'), mean=1.6))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=ResultSummary(
@@ -594,7 +594,7 @@
shape=(5, 8), dtype=dtype('float32'), mean=0.028698336),
alignment_history=())
expected_final_alignment_history = ResultSummary(
- shape=(3, 5, 8), dtype=dtype('float32'), mean=0.046009291)
+ shape=(3, 5, 8), dtype=dtype('float32'), mean=0.04865776002407074)
self._testWithAttention(
create_attention_mechanism,
@@ -761,9 +761,9 @@
expected_final_output = BasicDecoderOutput(
rnn_output=ResultSummary(
- shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11691988),
+ shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11798714846372604),
sample_id=ResultSummary(
- shape=(5, 3), dtype=dtype('int32'), mean=7.2666666666666666))
+ shape=(5, 3), dtype=dtype('int32'), mean=7.933333333333334))
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=ResultSummary(
@@ -771,7 +771,7 @@
h=ResultSummary(
shape=(5, 9), dtype=dtype('float32'), mean=-0.0018835809)),
attention=ResultSummary(
- shape=(5, 20), dtype=dtype('float32'), mean=0.11680689),
+ shape=(5, 20), dtype=dtype('float32'), mean=0.11798714846372604),
time=3,
alignments=(
ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index cddd628..a998ac1 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -237,8 +237,7 @@
if params.inference_tree_paths:
model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
- if params.regression:
- model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
+ model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
return model_ops
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index eb93876..3650b5d 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -478,8 +478,7 @@
**inference_args: Keyword arguments to pass through to each tree.
Returns:
- A tuple of (probabilities, tree_paths, variance), where variance
- is the variance over all the trees for regression problems only.
+ A tuple of (probabilities, tree_paths, variance).
Raises:
NotImplementedError: If trying to use feature bagging with sparse
@@ -513,13 +512,12 @@
self.params.num_trees,
name='probabilities')
tree_paths = array_ops.stack(paths, axis=1)
- regression_variance = None
- if self.params.regression:
- expected_squares = math_ops.div(
- math_ops.reduce_sum(all_predict * all_predict, 1),
- self.params.num_trees)
- regression_variance = math_ops.maximum(
- 0., expected_squares - average_values * average_values)
+
+ expected_squares = math_ops.div(
+ math_ops.reduce_sum(all_predict * all_predict, 1),
+ self.params.num_trees)
+ regression_variance = math_ops.maximum(
+ 0., expected_squares - average_values * average_values)
return average_values, tree_paths, regression_variance
def average_size(self):
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index 113dfb8..bbe627b 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -108,7 +108,7 @@
probs, paths, var = graph_builder.inference_graph(input_data)
self.assertTrue(isinstance(probs, ops.Tensor))
self.assertTrue(isinstance(paths, ops.Tensor))
- self.assertIsNone(var)
+ self.assertTrue(isinstance(var, ops.Tensor))
def testTrainingConstructionClassificationSparse(self):
input_data = sparse_tensor.SparseTensor(
diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc
index 1aff789..0514fce 100644
--- a/tensorflow/contrib/tensorboard/db/schema.cc
+++ b/tensorflow/contrib/tensorboard/db/schema.cc
@@ -19,7 +19,8 @@
Status Run(Sqlite* db, const char* sql) {
auto stmt = db->Prepare(sql);
- TF_RETURN_WITH_CONTEXT_IF_ERROR(stmt.StepAndReset(), sql);
+ TF_RETURN_IF_ERROR(stmt.status());
+ TF_RETURN_IF_ERROR(stmt.ValueOrDie().StepAndReset());
return Status::OK();
}
@@ -28,11 +29,11 @@
Status SetupTensorboardSqliteDb(Sqlite* db) {
// Note: GCC raw strings macros are broken.
// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971
- db->Prepare(strings::StrCat("PRAGMA application_id=",
- kTensorboardSqliteApplicationId))
- .StepAndReset()
- .IgnoreError();
- db->Prepare("PRAGMA user_version=0").StepAndReset().IgnoreError();
+ TF_RETURN_IF_ERROR(
+ db->PrepareOrDie(strings::StrCat("PRAGMA application_id=",
+ kTensorboardSqliteApplicationId))
+ .StepAndReset());
+ db->PrepareOrDie("PRAGMA user_version=0").StepAndResetOrDie();
Status s;
// Creates Ids table.
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
index fc201be..a9d75e9 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
@@ -133,7 +133,8 @@
class IdAllocator {
public:
IdAllocator(Env* env, Sqlite* db)
- : env_{env}, inserter_{db->Prepare("INSERT INTO Ids (id) VALUES (?)")} {}
+ : env_{env},
+ inserter_{db->PrepareOrDie("INSERT INTO Ids (id) VALUES (?)")} {}
Status CreateNewId(int64* id) {
Status s;
@@ -208,7 +209,7 @@
}
Status SaveNodeInputs() {
- auto insert = db_->Prepare(R"sql(
+ auto insert = db_->PrepareOrDie(R"sql(
INSERT INTO NodeInputs (graph_id, node_id, idx, input_node_id, is_control)
VALUES (?, ?, ?, ?, ?)
)sql");
@@ -236,7 +237,7 @@
}
Status SaveNodes() {
- auto insert = db_->Prepare(R"sql(
+ auto insert = db_->PrepareOrDie(R"sql(
INSERT INTO Nodes (graph_id, node_id, node_name, op, device, node_def)
VALUES (?, ?, ?, ?, ?, snap(?))
)sql");
@@ -262,7 +263,7 @@
}
Status SaveGraph() {
- auto insert = db_->Prepare(R"sql(
+ auto insert = db_->PrepareOrDie(R"sql(
INSERT INTO Graphs (graph_id, inserted_time, graph_def)
VALUES (?, ?, snap(?))
)sql");
@@ -291,14 +292,14 @@
experiment_name_{experiment_name},
run_name_{run_name},
user_name_{user_name},
- insert_tensor_{db_->Prepare(R"sql(
+ insert_tensor_{db_->PrepareOrDie(R"sql(
INSERT OR REPLACE INTO Tensors (tag_id, step, computed_time, tensor)
VALUES (?, ?, ?, snap(?))
)sql")} {}
~RunWriter() {
if (run_id_ == kAbsent) return;
- auto update = db_->Prepare(R"sql(
+ auto update = db_->PrepareOrDie(R"sql(
UPDATE Runs SET finished_time = ? WHERE run_id = ?
)sql");
update.BindDouble(1, GetWallTime(env_));
@@ -331,7 +332,8 @@
TF_RETURN_IF_ERROR(
GraphSaver::Save(env_, db_.get(), &id_allocator_, g.get(), &graph_id));
if (run_id_ != kAbsent) {
- auto set = db_->Prepare("UPDATE Runs SET graph_id = ? WHERE run_id = ?");
+ auto set =
+ db_->PrepareOrDie("UPDATE Runs SET graph_id = ? WHERE run_id = ?");
set.BindInt(1, graph_id);
set.BindInt(2, run_id_);
TF_RETURN_IF_ERROR(set.StepAndReset());
@@ -350,14 +352,14 @@
TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(tag_id));
tag_ids_[tag_name] = *tag_id;
if (!metadata.summary_description().empty()) {
- SqliteStatement insert_description = db_->Prepare(R"sql(
+ SqliteStatement insert_description = db_->PrepareOrDie(R"sql(
INSERT INTO Descriptions (id, description) VALUES (?, ?)
)sql");
insert_description.BindInt(1, *tag_id);
insert_description.BindText(2, metadata.summary_description());
TF_RETURN_IF_ERROR(insert_description.StepAndReset());
}
- SqliteStatement insert = db_->Prepare(R"sql(
+ SqliteStatement insert = db_->PrepareOrDie(R"sql(
INSERT INTO Tags (
run_id,
tag_id,
@@ -387,7 +389,7 @@
private:
Status InitializeUser() {
if (user_id_ != kAbsent || user_name_.empty()) return Status::OK();
- SqliteStatement get = db_->Prepare(R"sql(
+ SqliteStatement get = db_->PrepareOrDie(R"sql(
SELECT user_id FROM Users WHERE user_name = ?
)sql");
get.BindText(1, user_name_);
@@ -398,7 +400,7 @@
return Status::OK();
}
TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&user_id_));
- SqliteStatement insert = db_->Prepare(R"sql(
+ SqliteStatement insert = db_->PrepareOrDie(R"sql(
INSERT INTO Users (user_id, user_name, inserted_time) VALUES (?, ?, ?)
)sql");
insert.BindInt(1, user_id_);
@@ -412,7 +414,7 @@
if (experiment_name_.empty()) return Status::OK();
if (experiment_id_ == kAbsent) {
TF_RETURN_IF_ERROR(InitializeUser());
- SqliteStatement get = db_->Prepare(R"sql(
+ SqliteStatement get = db_->PrepareOrDie(R"sql(
SELECT
experiment_id,
started_time
@@ -432,7 +434,7 @@
} else {
TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&experiment_id_));
experiment_started_time_ = computed_time;
- SqliteStatement insert = db_->Prepare(R"sql(
+ SqliteStatement insert = db_->PrepareOrDie(R"sql(
INSERT INTO Experiments (
user_id,
experiment_id,
@@ -451,7 +453,7 @@
}
if (computed_time < experiment_started_time_) {
experiment_started_time_ = computed_time;
- SqliteStatement update = db_->Prepare(R"sql(
+ SqliteStatement update = db_->PrepareOrDie(R"sql(
UPDATE Experiments SET started_time = ? WHERE experiment_id = ?
)sql");
update.BindDouble(1, computed_time);
@@ -467,7 +469,7 @@
if (run_id_ == kAbsent) {
TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&run_id_));
run_started_time_ = computed_time;
- SqliteStatement insert = db_->Prepare(R"sql(
+ SqliteStatement insert = db_->PrepareOrDie(R"sql(
INSERT OR REPLACE INTO Runs (
experiment_id,
run_id,
@@ -485,7 +487,7 @@
}
if (computed_time < run_started_time_) {
run_started_time_ = computed_time;
- SqliteStatement update = db_->Prepare(R"sql(
+ SqliteStatement update = db_->PrepareOrDie(R"sql(
UPDATE Runs SET started_time = ? WHERE run_id = ?
)sql");
update.BindDouble(1, computed_time);
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
index 5ea844b..cfc6192 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
@@ -48,7 +48,7 @@
class SummaryDbWriterTest : public ::testing::Test {
protected:
- void SetUp() override { db_ = Sqlite::Open(":memory:").ValueOrDie(); }
+ void SetUp() override { db_ = Sqlite::OpenOrDie(":memory:"); }
void TearDown() override {
if (writer_ != nullptr) {
@@ -58,7 +58,7 @@
}
int64 QueryInt(const string& sql) {
- SqliteStatement stmt = db_->Prepare(sql);
+ SqliteStatement stmt = db_->PrepareOrDie(sql);
bool is_done;
Status s = stmt.Step(&is_done);
if (!s.ok() || is_done) {
@@ -69,7 +69,7 @@
}
double QueryDouble(const string& sql) {
- SqliteStatement stmt = db_->Prepare(sql);
+ SqliteStatement stmt = db_->PrepareOrDie(sql);
bool is_done;
Status s = stmt.Step(&is_done);
if (!s.ok() || is_done) {
@@ -80,7 +80,7 @@
}
string QueryString(const string& sql) {
- SqliteStatement stmt = db_->Prepare(sql);
+ SqliteStatement stmt = db_->PrepareOrDie(sql);
bool is_done;
Status s = stmt.Step(&is_done);
if (!s.ok() || is_done) {
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index ff100ca..848e5a6 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -176,7 +176,9 @@
additional_deps = [
":tpu",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:framework",
+ "//tensorflow/python:layers",
],
)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 79cda18..8fec379 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -142,8 +142,9 @@
def _AddOpInternal(self, op):
# pylint: disable=protected-access
if op.type in _BLACKLISTED_OPS:
- raise ValueError("Operation of type %s (%s) is not supported on the TPU" %
- (op.type, op.name))
+ logging.error("Operation of type %s (%s) is not supported on the TPU. "
+ "Execution will fail if this op is used in the graph. " %
+ (op.type, op.name))
if op.type in _NOT_IMPLEMENTED_OPS:
self._unsupported_ops.append(op)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py
index 2de5419..336d826 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_test.py
@@ -20,8 +20,14 @@
from __future__ import print_function
from tensorflow.contrib.tpu.python.tpu import tpu
+from tensorflow.contrib.tpu.python.tpu import tpu_feed
+from tensorflow.contrib.tpu.python.tpu import training_loop
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -39,5 +45,36 @@
self.assertTrue(control_flow_util.IsInXLAContext(z2.op))
+class TPULayerRewriteTest(test.TestCase):
+
+ def testUsingInfeedQueueWithRegularizer(self):
+ """Test that Layer regularizers can reference data created in loops."""
+
+ def make_regularizer(scale):
+ return lambda inputs: scale * math_ops.reduce_sum(math_ops.square(inputs))
+
+ def training_step(inputs, scale):
+ outputs = convolutional.conv2d(
+ inputs,
+ filters=16,
+ kernel_size=(3, 3),
+ data_format="channels_first",
+ kernel_regularizer=make_regularizer(scale))
+ loss = math_ops.reduce_mean(math_ops.square(outputs))
+ return loss.op
+
+ inputs = array_ops.zeros(shape=(128, 32, 32, 16))
+ scale = array_ops.ones(shape=())
+ infeed = tpu_feed.InfeedQueue(
+ tuple_types=[dtypes.float32, dtypes.float32],
+ tuple_shapes=[inputs.shape, scale.shape])
+
+ def loop():
+ return training_loop.repeat(5, training_step, infeed_queue=infeed)
+
+ # This should not throw an error.
+ tpu.rewrite(loop)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index d032cf9..ef8ca3f 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -363,6 +363,23 @@
],
)
+cc_library(
+ name = "abi",
+ srcs = ["platform/abi.cc"],
+ hdrs = ["platform/abi.h"],
+)
+
+cc_library(
+ name = "stacktrace_handler",
+ srcs = ["platform/stacktrace_handler.cc"],
+ hdrs = ["platform/stacktrace_handler.h"],
+ deps = [
+ ":abi",
+ ":lib",
+ ":lib_platform",
+ ],
+)
+
# Test support library needed for all tests
# This is currently public, but may be made internal in the
# future. Try to avoid depending on it.
@@ -2359,6 +2376,7 @@
deps = [
":lib",
":lib_internal",
+ ":stacktrace_handler",
":test", # buildcleaner: keep
"//tensorflow/core/platform/default/build_config:test_main",
],
@@ -2429,6 +2447,7 @@
"platform/net_test.cc",
"platform/port_test.cc",
"platform/profile_utils/cpu_utils_test.cc",
+ "platform/stacktrace_handler_test.cc",
"platform/subprocess_test.cc",
],
deps = [
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 286266a..51d7f98 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -152,7 +152,6 @@
~FunctionLibraryRuntimeImpl() override;
Status Instantiate(const string& function_name, AttrSlice attrs,
- const InstantiateOptions& options,
Handle* handle) override;
Status ReleaseHandle(Handle handle) override;
@@ -224,7 +223,7 @@
Status GetOrCreateItem(Handle handle, Item** item);
Status InstantiateSymbolicGradient(const NameAttrList& func,
FunctionBody** g_body);
- bool IsLocalTarget(const InstantiateOptions& options);
+ bool IsLocalTarget(const AttrSlice& attrs);
AttrValueMap FixAttrs(const AttrSlice& attrs);
void RunRemote(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
@@ -353,8 +352,7 @@
// Try to instantiate this function for the func/attr. Maybe it's
// cached already.
Handle handle;
- TF_RETURN_IF_ERROR(
- Instantiate(ndef.op(), AttrSlice(&ndef.attr()), {}, &handle));
+ TF_RETURN_IF_ERROR(Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
const FunctionBody* fbody = GetFunctionBody(handle);
CHECK_NOTNULL(fbody);
@@ -413,7 +411,7 @@
// f is a user-defined function.
Handle f_handle;
TF_RETURN_IF_ERROR(
- Instantiate(func.name(), AttrSlice(&func.attr()), {}, &f_handle));
+ Instantiate(func.name(), AttrSlice(&func.attr()), &f_handle));
const FunctionBody* f_body = GetFunctionBody(f_handle);
CHECK_NOTNULL(f_body);
*g_body = SymbolicGradient(*f_body);
@@ -421,25 +419,42 @@
return Status::OK();
}
-bool FunctionLibraryRuntimeImpl::IsLocalTarget(
- const InstantiateOptions& options) {
+bool FunctionLibraryRuntimeImpl::IsLocalTarget(const AttrSlice& attrs) {
if (device_ == nullptr) return true;
- if (options.target.empty()) return true;
+ string target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
+ if (target.empty()) return true;
Device* target_device;
- if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
+ if (!device_mgr_->LookupDevice(target, &target_device).ok()) {
return false;
}
return target_device == device_;
}
-Status FunctionLibraryRuntimeImpl::Instantiate(
- const string& function_name, AttrSlice attrs,
- const InstantiateOptions& options, Handle* handle) {
- if (!IsLocalTarget(options)) {
- return parent_->Instantiate(function_name, attrs, options, handle);
+// Copies `attrs`, adding a "_target" of `device_name_` when none is set.
+AttrValueMap FunctionLibraryRuntimeImpl::FixAttrs(const AttrSlice& attrs) {
+  AttrValueMap value_map;
+  // Bind by reference: `auto it` would copy every (name, AttrValue) entry.
+  for (const auto& it : attrs) {
+    value_map[it.first] = it.second;
+  }
+  if (attrs.Find("_target") != nullptr) return value_map;
+  AttrValue v;
+  v.set_s(device_name_);
+  AddAttr("_target", v, &value_map);
+  return value_map;
+}
+
+Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
+ AttrSlice attrs,
+ Handle* handle) {
+ AttrValueMap value_map = FixAttrs(attrs);
+ AttrSlice new_attrs(&value_map);
+
+ if (!IsLocalTarget(new_attrs)) {
+ return parent_->Instantiate(function_name, new_attrs, handle);
}
- const string key = Canonicalize(function_name, attrs, options);
+ const string key = Canonicalize(function_name, new_attrs);
*handle = parent_->GetHandle(key);
if (*handle != kInvalidHandle) {
return Status::OK();
@@ -448,7 +463,7 @@
Status s;
FunctionBody* fbody = nullptr;
if (function_name == kGradientOp) {
- const AttrValue* f = attrs.Find(kFuncAttr);
+ const AttrValue* f = new_attrs.Find(kFuncAttr);
if (f == nullptr) {
return errors::InvalidArgument("SymbolicGradient is missing attr: f");
}
@@ -458,7 +473,7 @@
}
const string grad = lib_def_->FindGradient(func.name());
if (!grad.empty()) {
- return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
+ return Instantiate(grad, AttrSlice(&func.attr()), handle);
}
TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, &fbody));
} else {
@@ -466,7 +481,7 @@
if (fdef == nullptr) {
return errors::NotFound("Function ", function_name, " is not defined.");
}
- TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, &fbody));
+ TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, new_attrs, &fbody));
}
{
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 2dacace..d4181ff 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -191,14 +191,11 @@
Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
test::function::Attrs attrs,
FunctionLibraryRuntime::Handle* handle) {
- return flr->Instantiate(name, attrs, handle);
- }
-
- Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
- test::function::Attrs attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
- FunctionLibraryRuntime::Handle* handle) {
- return flr->Instantiate(name, attrs, options, handle);
+ Status status = flr->Instantiate(name, attrs, handle);
+ if (!status.ok()) {
+ return status;
+ }
+ return Status::OK();
}
Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name,
@@ -1091,7 +1088,8 @@
TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
Init({test::function::FindDevice()});
FunctionLibraryRuntime::Handle handle;
- TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {}, {"/device:CPU:1"}, &handle));
+ TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {{"_target", "/device:CPU:1"}},
+ &handle));
Tensor y;
FunctionLibraryRuntime::Options opts;
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 12947e2..53a1412 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -88,6 +88,16 @@
std::move(custom_kernel_creator), nullptr /* cluster_flr */) {}
/* static */
+string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
+    const AttrSlice& attrs) {
+  // Look the attr up via the pointer-returning Find(); a miss yields nullptr
+  // and no Status has to be constructed just to be discarded.
+  const AttrValue* value = attrs.Find("_target");
+  if (value == nullptr) return "";
+  return DeviceNameUtils::CanonicalizeDeviceName(value->s());
+}
+
+/* static */
Status ProcessFunctionLibraryRuntime::SendTensors(
const string& source_device, const string& target_device,
const string& key_prefix, int64 src_incarnation,
@@ -230,23 +240,22 @@
Status ProcessFunctionLibraryRuntime::Instantiate(
const string& function_name, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::Handle* handle) {
*handle = kInvalidHandle;
- FunctionLibraryRuntime* flr = GetFLR(options.target);
+ string target = ObtainFunctionTarget(attrs);
+ FunctionLibraryRuntime* flr = GetFLR(target);
if (flr != nullptr) {
- return flr->Instantiate(function_name, attrs, options, handle);
+ return flr->Instantiate(function_name, attrs, handle);
}
if (parent_ == nullptr) {
return errors::Internal(
- "Currently don't support instantiating functions on device: ",
- options.target);
+ "Currently don't support instantiating functions on device: ", target);
}
FunctionLibraryRuntime::Handle cluster_handle;
- TF_RETURN_IF_ERROR(parent_->Instantiate(function_name, *lib_def_, attrs,
- options, &cluster_handle));
+ TF_RETURN_IF_ERROR(
+ parent_->Instantiate(function_name, *lib_def_, attrs, &cluster_handle));
string function_key = Canonicalize(function_name, attrs);
- *handle = AddHandle(function_key, options.target, cluster_handle);
+ *handle = AddHandle(function_key, target, cluster_handle);
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 38003b7..3aa7b87 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -53,6 +53,11 @@
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator);
+ // Given a list of attrs on a function, extracts the "_target" attribute which
+ // indicates which device to run the function on. If it can't find the _target
+ // attribute, returns "". Canonicalizes the device name.
+ static string ObtainFunctionTarget(const AttrSlice& attrs);
+
// Sends `tensors_to_send` from `source_device` to `target_device` using
// `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
// Rendezvous. `device_context` should be the DeviceContext of the device
@@ -116,7 +121,6 @@
// Allows for function_name to be instantiated on different devices
// as specified in attrs.
Status Instantiate(const string& function_name, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::Handle* handle);
// Delegates to the local FLR that owns state corresponding to `handle` and
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index f11b7a8..270e46d 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -49,12 +49,10 @@
}
Status Run(const string& name, FunctionLibraryRuntime::Options opts,
- test::function::Attrs attrs,
- const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
- const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+ test::function::Attrs attrs, const std::vector<Tensor>& args,
+ std::vector<Tensor*> rets) {
FunctionLibraryRuntime::Handle handle;
- Status status =
- proc_flr_->Instantiate(name, attrs, instantiate_opts, &handle);
+ Status status = proc_flr_->Instantiate(name, attrs, &handle);
if (!status.ok()) {
return status;
}
@@ -144,6 +142,21 @@
rendezvous_->Unref();
}
+TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
+  // Without a "_target" attr the target is the empty string.
+  EXPECT_EQ("", ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
+                    AttrSlice()));
+
+  // A "_target" attr is returned with its device name canonicalized.
+  AttrValue target_value;
+  target_value.set_s("/job:a/replica:0/task:0/cpu:1");
+  AttrValueMap attr_map;
+  AddAttr("_target", target_value, &attr_map);
+  EXPECT_EQ("/job:a/replica:0/task:0/device:CPU:1",
+            ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
+                AttrSlice(&attr_map)));
+}
+
TEST_F(ProcessFunctionLibraryRuntimeTest, GetDeviceIncarnation) {
Init({});
int64 incarnation;
@@ -165,8 +178,10 @@
opts.remote_execution = true;
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
- TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}},
- {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y}));
+ TF_CHECK_OK(
+ Run("XTimesTwo", opts,
+ {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
+ {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
rendezvous_->Unref();
}
@@ -178,8 +193,8 @@
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
- TF_CHECK_OK(
- Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:0"}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts,
+ {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"},
TensorShape({})));
@@ -194,11 +209,15 @@
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
- TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}},
- {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y}));
+ TF_CHECK_OK(
+ Run("XTimesTwo", opts,
+ {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
+ {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
- TF_CHECK_OK(Run("XTimesFour", opts, {{"T", DT_FLOAT}},
- {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y}));
+ TF_CHECK_OK(
+ Run("XTimesFour", opts,
+ {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
+ {&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
rendezvous_->Unref();
}
@@ -210,13 +229,13 @@
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
- TF_CHECK_OK(
- Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:1"}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts,
+ {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"},
TensorShape({})));
- TF_CHECK_OK(
- Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:1"}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts,
+ {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"},
TensorShape({})));
@@ -230,13 +249,11 @@
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
- TF_CHECK_OK(Run("FindDevice", opts, {},
- {"/job:a/replica:0/task:0/device:CPU:0"}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:0"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"},
TensorShape({})));
- TF_CHECK_OK(Run("FindDevice", opts, {},
- {"/job:a/replica:0/task:0/device:CPU:1"}, {}, {&y}));
+ TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"},
TensorShape({})));
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
index 3a8d591..d84b69d 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
@@ -26,10 +26,10 @@
/* static */
Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
- const OpDef& sig, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
+ const OpDef& sig, AttrSlice attrs, GraphDef* g,
std::vector<string>* send_keys, std::vector<string>* recv_keys) {
- const string& target = options.target;
+ const string& target =
+ ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
// Construct recv nodes for each input argument.
int i = 0;
for (const auto& in : sig.input_arg()) {
@@ -119,16 +119,16 @@
Status ClusterFunctionLibraryRuntime::Instantiate(
const string& function_name, const FunctionLibraryDefinition& lib_def,
- AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
- FunctionLibraryRuntime::LocalHandle* handle) {
- WorkerInterface* wi =
- worker_session_->worker_cache->CreateWorker(options.target);
+ AttrSlice attrs, FunctionLibraryRuntime::LocalHandle* handle) {
+ const string& target =
+ ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
+ WorkerInterface* wi = worker_session_->worker_cache->CreateWorker(target);
if (wi == nullptr) {
std::vector<string> workers;
worker_session_->worker_cache->ListWorkers(&workers);
return errors::InvalidArgument(
- "Could not find worker with target: ", options.target,
+ "Could not find worker with target: ", target,
" Available workers: ", str_util::Join(workers, ", "));
}
@@ -137,8 +137,8 @@
const OpDef& sig = fdef->signature();
GraphDef gdef;
std::vector<string> send_keys, recv_keys;
- TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, &gdef,
- &send_keys, &recv_keys));
+ TF_RETURN_IF_ERROR(
+ ConstructFunctionGraph(sig, attrs, &gdef, &send_keys, &recv_keys));
*gdef.mutable_library() = lib_def.ToProto();
RegisterGraphRequest req;
@@ -152,8 +152,8 @@
mutex_lock l(mu_);
*handle = function_data_.size();
- function_data_.push_back(FunctionData(resp.graph_handle(), options.target, wi,
- send_keys, recv_keys));
+ function_data_.push_back(
+ FunctionData(resp.graph_handle(), target, wi, send_keys, recv_keys));
return Status::OK();
}
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
index 3deb80d..dd4ea68 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
@@ -34,7 +34,6 @@
Status Instantiate(const string& function_name,
const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::LocalHandle* handle) override;
void Run(const FunctionLibraryRuntime::Options& opts,
@@ -43,10 +42,10 @@
FunctionLibraryRuntime::DoneCallback done) override;
private:
- static Status ConstructFunctionGraph(
- const OpDef& sig, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
- std::vector<string>* send_keys, std::vector<string>* recv_keys);
+ static Status ConstructFunctionGraph(const OpDef& sig, AttrSlice attrs,
+ GraphDef* g,
+ std::vector<string>* send_keys,
+ std::vector<string>* recv_keys);
friend class ClusterFunctionLibraryRuntimeTest;
mutable mutex mu_;
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
index 98512bc..6dd8b9e 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
@@ -47,31 +47,30 @@
new ClusterFunctionLibraryRuntime(worker_session_.get()));
}
- Status ConstructFunctionGraphHelper(
- const OpDef& sig, test::function::Attrs attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
- std::vector<string>* send_keys, std::vector<string>* recv_keys) {
+ Status ConstructFunctionGraphHelper(const OpDef& sig,
+ test::function::Attrs attrs, GraphDef* g,
+ std::vector<string>* send_keys,
+ std::vector<string>* recv_keys) {
return ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
- sig, attrs, options, g, send_keys, recv_keys);
+ sig, attrs, g, send_keys, recv_keys);
}
Status Instantiate(const string& function_name,
const FunctionLibraryDefinition& lib_def,
test::function::Attrs attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
FunctionLibraryRuntime::LocalHandle* local_handle) {
- return cluster_flr_->Instantiate(function_name, lib_def, attrs, options,
+ return cluster_flr_->Instantiate(function_name, lib_def, attrs,
local_handle);
}
- Status InstantiateAndRun(
- const string& function_name, const FunctionLibraryDefinition& lib_def,
- test::function::Attrs attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
- const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+ Status InstantiateAndRun(const string& function_name,
+ const FunctionLibraryDefinition& lib_def,
+ test::function::Attrs attrs,
+ const std::vector<Tensor>& args,
+ std::vector<Tensor*> rets) {
FunctionLibraryRuntime::LocalHandle handle;
- TF_RETURN_IF_ERROR(cluster_flr_->Instantiate(function_name, lib_def, attrs,
- options, &handle));
+ TF_RETURN_IF_ERROR(
+ cluster_flr_->Instantiate(function_name, lib_def, attrs, &handle));
Notification done;
FunctionLibraryRuntime::Options opts;
@@ -104,9 +103,9 @@
GraphDef actual;
std::vector<string> send_keys, recv_keys;
TF_CHECK_OK(ConstructFunctionGraphHelper(
- test::function::Swap().signature(), {{"T", DT_FLOAT}},
- {"/job:a/replica:0/task:0/device:CPU:0"}, &actual, &send_keys,
- &recv_keys));
+ test::function::Swap().signature(),
+ {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, &actual,
+ &send_keys, &recv_keys));
GraphDef expected;
protobuf::TextFormat::ParseFromString(R"(
node {
@@ -206,7 +205,7 @@
attr {
key: "_target"
value {
- s: "/job:a/replica:0/task:0/device:CPU:0"
+ s: "/job:a/replica:0/task:0/cpu:0"
}
}
}
@@ -310,9 +309,9 @@
Tensor y;
auto x = test::AsTensor<int32>({1, 2, 3, 4});
- TF_EXPECT_OK(InstantiateAndRun("XTimesTwoInt32", lib_def, {},
- {"/job:localhost/replica:0/task:1/cpu:0"}, {x},
- {&y}));
+ TF_EXPECT_OK(InstantiateAndRun(
+ "XTimesTwoInt32", lib_def,
+ {{"_target", "/job:localhost/replica:0/task:1/cpu:0"}}, {x}, {&y}));
test::ExpectTensorEqual<int32>(y, test::AsTensor<int32>({2, 4, 6, 8}));
}
@@ -325,9 +324,10 @@
Tensor y1, y2;
auto x1 = test::AsTensor<float>({1, 2, 3, 4});
auto x2 = test::AsTensor<float>({4, 3, 2, 1});
- TF_EXPECT_OK(InstantiateAndRun("Swap", lib_def, {{"T", DT_FLOAT}},
- {"/job:localhost/replica:0/task:1/cpu:0"},
- {x1, x2}, {&y1, &y2}));
+ TF_EXPECT_OK(InstantiateAndRun(
+ "Swap", lib_def,
+ {{"T", DT_FLOAT}, {"_target", "/job:localhost/replica:0/task:1/cpu:0"}},
+ {x1, x2}, {&y1, &y2}));
test::ExpectTensorEqual<float>(y1, test::AsTensor<float>({4, 3, 2, 1}));
test::ExpectTensorEqual<float>(y2, test::AsTensor<float>({1, 2, 3, 4}));
}
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 7830154..d757e96 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -795,17 +795,12 @@
return h;
}
-string Canonicalize(const string& funcname, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options) {
+string Canonicalize(const string& funcname, AttrSlice attrs) {
std::vector<string> entries;
- entries.reserve(options.target.empty() ? attrs.size() : (attrs.size() + 1));
+ entries.reserve(attrs.size());
for (auto p : attrs) {
entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
}
- if (!options.target.empty()) {
- entries.push_back(
- strings::StrCat("_target", "=", str_util::CEscape(options.target)));
- }
std::sort(entries.begin(), entries.end());
return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
}
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index e5d0e49..1a579ab 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -234,6 +234,15 @@
// same.
uint64 FunctionDefHash(const FunctionDef& fdef);
+// Returns a canonicalized string for the instantiation of the
+// function of the given "name" and attributes "attrs".
+//
+// The returned string is guaranteed to be stable within one address
+// space. But it may change as the implementation
+// evolves. Therefore, it should not be persisted or compared across
+// address spaces.
+string Canonicalize(const string& funcname, AttrSlice attrs);
+
class CallFrameInterface {
public:
virtual ~CallFrameInterface() {}
@@ -409,23 +418,9 @@
//
// Returns OK and fills in "handle" if the instantiation succeeds.
// Otherwise returns an error and "handle" is undefined.
- struct InstantiateOptions {
- // The canonical device name of the device on which the function
- // should be instantiated. If empty, the function will be
- // instantiated on the local device.
- string target;
-
- // TODO(b/70352992): Add an API for allowing a different
- // FunctionLibraryDefinition to be overlaid on this runtime's library.
- };
typedef uint64 Handle;
virtual Status Instantiate(const string& function_name, AttrSlice attrs,
- const InstantiateOptions& options,
Handle* handle) = 0;
- Status Instantiate(const string& function_name, AttrSlice attrs,
- Handle* handle) {
- return Instantiate(function_name, attrs, {}, handle);
- }
// Releases state associated with the handle.
virtual Status ReleaseHandle(Handle handle) = 0;
@@ -507,19 +502,6 @@
typedef uint64 LocalHandle;
};
-// Returns a canonicalized string for the instantiation of the
-// function of the given "name", attributes "attrs", and "options".
-//
-// The returned string is guaranteed to be stable within one address
-// space. But it may be change as the implementation
-// evolves. Therefore, it should not be persisted or compared across
-// address spaces.
-string Canonicalize(const string& funcname, AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options);
-inline string Canonicalize(const string& funcname, AttrSlice attrs) {
- return Canonicalize(funcname, attrs, {});
-}
-
const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&,
@@ -532,11 +514,10 @@
virtual ~DistributedFunctionLibraryRuntime() {}
// The _target attr in attrs determines where the function is instantiated.
- virtual Status Instantiate(
- const string& function_name, const FunctionLibraryDefinition& lib_def,
- AttrSlice attrs,
- const FunctionLibraryRuntime::InstantiateOptions& options,
- FunctionLibraryRuntime::LocalHandle* handle) = 0;
+ virtual Status Instantiate(const string& function_name,
+ const FunctionLibraryDefinition& lib_def,
+ AttrSlice attrs,
+ FunctionLibraryRuntime::LocalHandle* handle) = 0;
// opts.runner isn't used for execution.
virtual void Run(const FunctionLibraryRuntime::Options& opts,
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 4771840..743c76d 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -347,6 +347,21 @@
} // namespace
+Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
+ int input_port, DataType* input_type) {
+ DataTypeVector input_types;
+ for (const auto& arg : op_def.input_arg()) {
+ TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
+ if (input_types.size() > input_port) {
+ const DataType dtype = input_types[input_port];
+ *input_type = dtype;
+ return Status::OK();
+ }
+ }
+ return errors::InvalidArgument("Input ", input_port, " not found for node ",
+ node_def.name());
+}
+
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs, DataTypeVector* outputs) {
for (const auto& arg : op_def.input_arg()) {
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 4e98522..812cf1b 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -239,6 +239,11 @@
// REQUIRES: Must not use the returned value beyond the lifetime of node_def.
const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name);
+// Computes the input type for a specific node input.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
+ int input_port, DataType* input_type);
+
// Computes the input and output types for a specific node.
// REQUIRES: ValidateOpDef(op_def).ok()
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
diff --git a/tensorflow/core/grappler/clusters/single_machine_test.cc b/tensorflow/core/grappler/clusters/single_machine_test.cc
index df936ef..2684100 100644
--- a/tensorflow/core/grappler/clusters/single_machine_test.cc
+++ b/tensorflow/core/grappler/clusters/single_machine_test.cc
@@ -58,6 +58,10 @@
std::unique_ptr<SingleMachine> cluster_;
};
+TEST_F(SingleMachineTest, ClusterType) {
+ CHECK_EQ("single_machine", cluster_->type());
+}
+
TEST_F(SingleMachineTest, CostModel) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
cluster_->GetDeviceNames());
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster_test.cc b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc
index fd925a6..357b306 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster_test.cc
+++ b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc
@@ -56,6 +56,10 @@
std::unique_ptr<VirtualCluster> cluster_;
};
+TEST_F(VirtualClusterTest, ClusterType) {
+ CHECK_EQ("virtual", cluster_->type());
+}
+
TEST_F(VirtualClusterTest, CostModel) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
cluster_->GetDeviceNames());
diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc
index 6022c47..3168758 100644
--- a/tensorflow/core/grappler/costs/graph_memory.cc
+++ b/tensorflow/core/grappler/costs/graph_memory.cc
@@ -32,7 +32,17 @@
const std::unordered_map<string, DeviceProperties>& devices) {
VirtualCluster cluster(devices);
TF_RETURN_IF_ERROR(cluster.Provision());
- return InferDynamically(&cluster);
+ TF_RETURN_IF_ERROR(cluster.Initialize(item_));
+ RunMetadata metadata;
+ Status s = cluster.Run(item_.graph, item_.feed, item_.fetch, &metadata);
+ // The virtual cluster returns the RESOURCE_EXHAUSTED error when it detects
+ // that the model would run out of memory. We still get the metadata we need
+ // out of the simulation, so we just ignore this error.
+ if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
+ return s;
+ }
+ InferFromTrace(metadata.step_stats());
+ return Status::OK();
}
Status GraphMemory::InferDynamically(Cluster* cluster) {
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index a24310a..63dfade 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -29,8 +29,8 @@
class GraphView {
public:
struct Port {
- NodeDef* node;
- int port_id;
+ NodeDef* node = nullptr;
+ int port_id = -1;
bool operator==(const Port& other) const {
return node == other.node && port_id == other.port_id;
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index e557adc..791ad34 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -279,6 +279,7 @@
":graph_optimizer",
":graph_rewriter",
":static_schedule",
+ "//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 37610d2..9b7f572 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -1699,10 +1699,6 @@
int input_port;
auto input = node_map_->GetNode(node_->input(0));
ParseNodeName(node_->input(0), &input_port);
- if (IsTransposeNCHWToNHWC(input->name())) {
- input = node_map_->GetNode(input->input(0));
- ParseNodeName(input->input(0), &input_port);
- }
if (input->attr().find("_output_shapes") != input->attr().end()) {
auto shape = input->attr().at("_output_shapes").list().shape(input_port);
if (shape.dim_size() != 4) {
@@ -1745,13 +1741,28 @@
int port;
ParseNodeName(node_->input(0), &port);
return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
- IsPortDimsFour(*input0, port) && IsAlongAllFourDims() && IsOnGPU();
+ IsPortDimsFour(*input0, port) && IsReduceAxisSupported() &&
+ IsOnGPU();
+ }
+
+ Status CustomizedProcessing() override {
+ if (IsAlongNHW() || IsAlongHW() || IsAlongC()) {
+ DataType dtype = node_->attr().at("Tidx").type();
+ TF_RETURN_IF_ERROR(
+ UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype));
+ }
+ return Status::OK();
}
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
private:
- bool IsAlongAllFourDims() const {
+ bool IsReduceAxisSupported() const {
+ return IsAlongAllFourDims() || IsAlongHWC() ||
+ ((IsAlongNHW() || IsAlongHW() || IsAlongC()) && !KeepDims());
+ }
+
+ bool IsAlongAxis(const std::vector<int>& axis) const {
auto axis_node = node_map_->GetNode(node_->input(1));
if (!IsConstant(*axis_node)) {
return false;
@@ -1762,15 +1773,28 @@
if (!success) {
LOG(ERROR) << "Failed to parse TensorProto.";
}
- if (tensor.dims() == 1 && tensor.dim_size(0) == 4) {
- if (tensor.flat<int>()(0) == 0 && tensor.flat<int>()(1) == 1 &&
- tensor.flat<int>()(2) == 2 && tensor.flat<int>()(3) == 3) {
- return true;
+ if (tensor.dims() == 1 && tensor.dim_size(0) == axis.size()) {
+ bool along_axis = true;
+ for (int i = 0; i < axis.size(); i++) {
+ along_axis = along_axis && (tensor.flat<int>()(i) == axis[i]);
}
+ if (along_axis) return true;
}
}
return false;
}
+
+ bool IsAlongAllFourDims() const { return IsAlongAxis({0, 1, 2, 3}); }
+
+ bool IsAlongHWC() const { return IsAlongAxis({1, 2, 3}); }
+
+ bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
+
+ bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
+
+ bool IsAlongC() const { return IsAlongAxis({3}); }
+
+ bool KeepDims() const { return node_->attr().at("keep_dims").b(); }
};
class SwitchProcessor : public AgnosticNodeProcessor {
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 1420fdb..bb4839d 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -23,6 +23,7 @@
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -480,8 +481,19 @@
}
}
-std::pair<NodeDef*, NodeDef*> BuildSwapPair(NodeDef* node, int input_to_swap,
- GraphDef* graph) {
+Status BuildSwapPair(NodeDef* node, int input_to_swap, GraphDef* graph,
+ std::pair<NodeDef*, NodeDef*>* swap_pair) {
+ const OpDef* op_def;
+ TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
+ DataType input_type;
+ TF_RETURN_IF_ERROR(
+ InputTypeForNode(*node, *op_def, input_to_swap, &input_type));
+ if (IsRefType(input_type)) {
+ return errors::InvalidArgument("Can't swap input ", input_to_swap,
+ " of node ", node->name(),
+ " since it expects a reference");
+ }
+
string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);
// Force the tensor to be copied to cpu.
@@ -501,10 +513,11 @@
(*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
(*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
- const DataType input_type = node->attr().at("T").type();
(*swap_in_node->mutable_attr())["T"].set_type(input_type);
(*swap_out_node->mutable_attr())["T"].set_type(input_type);
- return std::make_pair(swap_out_node, swap_in_node);
+ *swap_pair = std::make_pair(swap_out_node, swap_in_node);
+
+ return Status::OK();
}
static int64 EstimateSize(const OpInfo::TensorProperties& t) {
@@ -568,9 +581,12 @@
max_trigger_time -= swap_info.time_to_swap;
std::map<Costs::NanoSeconds, const NodeDef*> candidates;
+ std::set<string> already_processed;
+
while (!possible_inputs.empty()) {
const string input_node_name = *possible_inputs.begin();
possible_inputs.erase(possible_inputs.begin());
+ already_processed.insert(input_node_name);
auto it1 = name_map.find(input_node_name);
if (it1 == name_map.end()) {
return nullptr;
@@ -579,7 +595,7 @@
// Don't jump over frames, since adding a control dependency from one frame
// to the next isn't supported. Don't go through branches, since we don't
// know whether they'll be executed or not.
- if (IsNextIteration(*input_node) || IsSwitch(*input_node) ||
+ if (ModifiesFrameInfo(*input_node) || IsSwitch(*input_node) ||
IsMerge(*input_node)) {
continue;
}
@@ -591,7 +607,10 @@
candidates[it2->second] = input_node;
} else {
for (const string& fanin : input_node->input()) {
- possible_inputs.insert(NodeName(fanin));
+ string name = NodeName(fanin);
+ if (already_processed.find(name) == already_processed.end()) {
+ possible_inputs.insert(name);
+ }
}
}
}
@@ -605,13 +624,31 @@
return nullptr;
}
+static bool IsSwappable(GraphView::InputPort input) {
+ const NodeDef& node = *input.node;
+
+ const OpDef* op_def;
+ if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok()) {
+ return false;
+ }
+
+ DataType dtype;
+ if (!InputTypeForNode(*input.node, *op_def, input.port_id, &dtype).ok()) {
+ return false;
+ }
+
+ return !IsRefType(dtype);
+}
+
static void IdentifySwappingCandidates(Cluster* cluster,
const GrapplerItem& item,
GraphDef* optimized_graph) {
GraphMemory memory(item);
const std::unordered_map<string, DeviceProperties>& devices =
cluster->GetDevices();
- if (!memory.InferStatically(devices).ok()) {
+ Status s = memory.InferStatically(devices);
+ if (!s.ok()) {
+ VLOG(1) << "Failed to infer memory usage: " << s.error_message();
return;
}
@@ -622,24 +659,36 @@
continue;
}
if (prop.memory_size() <= 0) {
+ VLOG(1) << "Peak memory usage unknown for device " << name;
continue;
}
const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
+
if (mem_usage.used_memory <= prop.memory_size()) {
continue;
}
int64 required_savings = mem_usage.used_memory - prop.memory_size();
// TODO(bsteiner): sort the tensors by how long they're live.
- std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
- if (!EstimateEarliestExecutionTimes(item, cluster, &execution_times).ok()) {
- return;
+ std::unordered_map<string, Costs::NanoSeconds> execution_times;
+ {
+ std::unordered_map<const NodeDef*, Costs::NanoSeconds>
+ tmp_execution_times;
+ if (!EstimateEarliestExecutionTimes(item, cluster, &tmp_execution_times)
+ .ok()) {
+ return;
+ }
+ for (const auto& exec_time : tmp_execution_times) {
+ execution_times.emplace(exec_time.first->name(), exec_time.second);
+ }
}
+
GraphView graph(optimized_graph);
for (const auto& live_tensor : mem_usage.live_tensors) {
if (live_tensor.deallocation_time - live_tensor.allocation_time <=
Costs::Duration(1e6)) {
// Not enough time to swap.
+ VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
continue;
}
if (live_tensor.memory_used <= 1024) {
@@ -651,7 +700,10 @@
GraphView::OutputPort port =
graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
for (GraphView::InputPort input : graph.GetFanout(port)) {
- auto it = execution_times.find(input.node);
+ if (!IsSwappable(input)) {
+ continue;
+ }
+ auto it = execution_times.find(input.node->name());
if (it != execution_times.end()) {
if (it->second > execution_time) {
fanout_to_swap = input;
@@ -661,15 +713,23 @@
}
// Annotate the fanout to request the tensor to be swapped if it's not
// already been done.
- AttrValue& val = (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
bool found = false;
- for (int port_id : val.list().i()) {
- if (port_id == fanout_to_swap.port_id) {
- found = true;
- break;
+ if (!fanout_to_swap.node) {
+ continue;
+ }
+ auto it = fanout_to_swap.node->attr().find("_swap_to_host");
+ if (it != fanout_to_swap.node->attr().end()) {
+ const AttrValue& val = it->second;
+ for (int port_id : val.list().i()) {
+ if (port_id == fanout_to_swap.port_id) {
+ found = true;
+ break;
+ }
}
}
if (!found) {
+ AttrValue& val =
+ (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
val.mutable_list()->add_i(fanout_to_swap.port_id);
required_savings -= live_tensor.memory_used;
if (required_savings < 0) {
@@ -688,7 +748,8 @@
recomputation_targets_name_prefix_,
optimized_graph, item);
- if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS) {
+ if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS &&
+ cluster != nullptr) {
IdentifySwappingCandidates(cluster, item, optimized_graph);
}
@@ -713,7 +774,6 @@
return Status::OK();
}
- {
// Estimate the size of the data to swap for each node.
GraphProperties properties(item);
TF_RETURN_IF_ERROR(properties.InferStatically(true));
@@ -730,7 +790,6 @@
// Let's assume we're going to swap over PCIe running at 16 GBps.
swap_info.time_to_swap = bytes_to_swap / 16;
}
- }
std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
TF_RETURN_IF_ERROR(
@@ -743,7 +802,7 @@
for (auto& swap : nodes_to_swap) {
NodeDef* node = swap.first;
- SwapInfo& swap_info = swap.second;
+ const SwapInfo& swap_info = swap.second;
// Make sure the tensor isn't swapped back in right away: look for node that
// will execute just before we need to swap the data back, and add a control
@@ -755,8 +814,10 @@
}
// Swap all the tensors that are marked with the 'swap_to_host' attribute.
for (int input_id : swap_info.inputs_to_swap) {
- std::pair<NodeDef*, NodeDef*> swap_nodes =
- BuildSwapPair(node, input_id, optimized_graph);
+ std::pair<NodeDef*, NodeDef*> swap_nodes;
+ if (!BuildSwapPair(node, input_id, optimized_graph, &swap_nodes).ok()) {
+ continue;
+ }
*swap_nodes.first->add_input() = node->input(input_id);
*node->mutable_input(input_id) = swap_nodes.second->name();
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
index 6fa4731..6448d1e 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
@@ -201,8 +201,16 @@
cpu_device.set_frequency(1000);
cpu_device.set_num_cores(4);
cpu_device.set_bandwidth(32);
+ DeviceProperties gpu_device;
+ gpu_device.set_type("GPU");
+ gpu_device.set_frequency(1000);
+ gpu_device.set_num_cores(24);
+ gpu_device.set_bandwidth(128);
+ gpu_device.set_memory_size(1024 * 1024);
+ gpu_device.mutable_environment()->insert({"architecture", "6"});
std::unordered_map<string, DeviceProperties> devices;
devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device;
+ devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device;
return std::unique_ptr<VirtualCluster>(new VirtualCluster(devices));
}
};
@@ -252,6 +260,74 @@
EXPECT_EQ("^c", swap_in.input(1));
}
+TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
+ {128, 128, 8}, DT_FLOAT);
+ Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
+ Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
+ Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), {a});
+ Output axis = ops::Const(s.WithOpName("axis"), 0);
+ Output e =
+ ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"e"};
+
+ std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
+
+ MemoryOptimizer optimizer(RewriterConfig::SWAPPING_HEURISTICS);
+ GraphDef output;
+ Status status = optimizer.Optimize(cluster.get(), item, &output);
+ TF_EXPECT_OK(status);
+
+ for (const auto& node : output.node()) {
+ if (node.name() == "e") {
+ EXPECT_TRUE(node.attr().count("_swap_to_host") > 0);
+ const AttrValue& val = node.attr().at("_swap_to_host");
+ EXPECT_TRUE(val.has_list());
+ std::set<int> inputs_to_swap;
+ for (int64 input_id : val.list().i()) {
+ inputs_to_swap.insert(input_id);
+ }
+ EXPECT_EQ(std::set<int>({0, 1, 2}), inputs_to_swap);
+ }
+ }
+}
+
+TEST_F(MemoryOptimizerTest, UnswappableInputs) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
+ {128, 128, 8}, DT_FLOAT);
+ Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
+ Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
+ Output index = ops::Const(s.WithOpName("index"), {0});
+ Output indices = ops::Tile(s.WithOpName("indices"), index, {128});
+ Output d =
+ ops::ScatterAdd(s.WithOpName("d").WithDevice("/gpu:0"), a, indices, c);
+ Output axis = ops::Const(s.WithOpName("axis"), 0);
+ Output e =
+ ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"e"};
+
+ std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
+
+ MemoryOptimizer optimizer(RewriterConfig::SWAPPING_HEURISTICS);
+ GraphDef output;
+ Status status = optimizer.Optimize(cluster.get(), item, &output);
+ TF_EXPECT_OK(status);
+
+ for (const auto& node : output.node()) {
+ if (node.name() == "d") {
+ EXPECT_EQ(0, node.attr().count("_swap_to_host"));
+ }
+ }
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 6899767..6f7a0fd 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -164,7 +164,7 @@
buffer_element.value.size()));
for (size_t j = 0; j < buffer_element.value.size(); j++) {
TF_RETURN_IF_ERROR(writer->WriteTensor(
- strings::StrCat("buffer[", i, "][", j, "]"),
+ full_name(strings::StrCat("buffer[", i, "][", j, "]")),
buffer_element.value[j]));
}
}
@@ -201,7 +201,7 @@
for (size_t j = 0; j < value_size; j++) {
buffer_element.value.emplace_back();
TF_RETURN_IF_ERROR(reader->ReadTensor(
- strings::StrCat("buffer[", i, "][", j, "]"),
+ full_name(strings::StrCat("buffer[", i, "][", j, "]")),
&buffer_element.value.back()));
}
}
diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
index abe3126..71af30e 100644
--- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
+++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
@@ -14,6 +14,7 @@
==============================================================================*/
#include "tensorflow/core/kernels/data/sql/sqlite_query_connection.h"
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
@@ -40,21 +41,15 @@
}
Status SqliteQueryConnection::Close() {
- Status s;
- s.Update(stmt_.Close());
- s.Update(db_->Close());
- return s;
+ // Replace the statement with an empty one first, then drop this object's
+ // reference to the connection. Neither step yields a status, so the
+ // result is always OK.
+ stmt_ = SqliteStatement();
+ db_.reset();
+ return Status::OK();
}
Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
- if (!stmt_) {
- Status s = PrepareQuery();
- if (!s.ok()) {
- return s;
- }
- }
- Status s = stmt_.Step(end_of_sequence);
+ if (!stmt_) TF_RETURN_IF_ERROR(PrepareQuery());
+ TF_RETURN_IF_ERROR(stmt_.Step(end_of_sequence));
if (!*end_of_sequence) {
for (int i = 0; i < column_count_; i++) {
DataType dt = output_types_[i];
@@ -63,64 +58,48 @@
out_tensors->emplace_back(std::move(tensor));
}
}
- return s;
+ return Status::OK();
}
Status SqliteQueryConnection::PrepareQuery() {
- stmt_ = db_->Prepare(query_);
- Status s = stmt_.status();
- if (s.ok()) {
- int column_count = stmt_.ColumnCount();
- if (column_count != output_types_.size()) {
- return errors::InvalidArgument(tensorflow::strings::Printf(
- "The number of columns in query (%d) must match the number of "
- "elements in output_types (%zu).",
- column_count, output_types_.size()));
- }
- column_count_ = column_count;
+ // Prepare into a local first; stmt_ and column_count_ are only updated
+ // after the column count has been validated against output_types_.
+ auto prep = db_->Prepare(query_);
+ TF_RETURN_IF_ERROR(prep.status());
+ int column_count = prep.ValueOrDie().ColumnCount();
+ if (column_count != output_types_.size()) {
+ return errors::InvalidArgument(tensorflow::strings::Printf(
+ "The number of columns in query (%d) must match the number of "
+ "elements in output_types (%zu).",
+ column_count, output_types_.size()));
}
- return s;
+ // Validation passed: take ownership of the prepared statement.
+ stmt_ = prep.ConsumeValueOrDie();
+ column_count_ = column_count;
+ return Status::OK();
}
void SqliteQueryConnection::FillTensorWithResultSetEntry(
const DataType& data_type, int column_index, Tensor* tensor) {
+#define CASE(T, M) \
+ case DataTypeToEnum<T>::value: \
+ tensor->scalar<T>()() = static_cast<T>(stmt_.M(column_index)); \
+ break;
+#define INT_CASE(T) CASE(T, ColumnInt)
+#define DOUBLE_CASE(T) CASE(T, ColumnDouble)
+#define STRING_CASE(T) CASE(T, ColumnString)
switch (data_type) {
- case DT_STRING:
- tensor->scalar<string>()() = stmt_.ColumnString(column_index);
- break;
- case DT_INT8:
- tensor->scalar<int8>()() =
- static_cast<int8>(stmt_.ColumnInt(column_index));
- break;
- case DT_INT16:
- tensor->scalar<int16>()() =
- static_cast<int16>(stmt_.ColumnInt(column_index));
- break;
- case DT_INT32:
- tensor->scalar<int32>()() =
- static_cast<int32>(stmt_.ColumnInt(column_index));
- break;
- case DT_INT64:
- tensor->scalar<int64>()() = stmt_.ColumnInt(column_index);
- break;
- case DT_UINT8:
- tensor->scalar<uint8>()() =
- static_cast<uint8>(stmt_.ColumnInt(column_index));
- break;
- case DT_UINT16:
- tensor->scalar<uint16>()() =
- static_cast<uint16>(stmt_.ColumnInt(column_index));
- break;
+ TF_CALL_int8(INT_CASE)
+ TF_CALL_uint8(INT_CASE)
+ TF_CALL_int16(INT_CASE)
+ TF_CALL_uint16(INT_CASE)
+ TF_CALL_int32(INT_CASE)
+ TF_CALL_uint32(INT_CASE)
+ TF_CALL_int64(INT_CASE)
+ TF_CALL_uint64(INT_CASE)
+ TF_CALL_float(DOUBLE_CASE)
+ TF_CALL_double(DOUBLE_CASE)
+ TF_CALL_string(STRING_CASE)
case DT_BOOL:
tensor->scalar<bool>()() = stmt_.ColumnInt(column_index) != 0;
break;
- case DT_FLOAT:
- tensor->scalar<float>()() =
- static_cast<float>(stmt_.ColumnDouble(column_index));
- break;
- case DT_DOUBLE:
- tensor->scalar<double>()() = stmt_.ColumnDouble(column_index);
- break;
// Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
default: {
LOG(FATAL)
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index facac10..f469f41 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -296,19 +296,21 @@
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
const Tensor* target;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
+ AttrValueMap attr_values = func_.attr();
+ AttrValue v;
const string& target_device =
DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()());
+ v.set_s(target_device);
+ AddAttr("_target", v, &attr_values);
FunctionLibraryRuntime* lib = ctx->function_library();
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
errors::Internal("No function library is provided."),
done);
- AttrValueMap attr_values = func_.attr();
FunctionLibraryRuntime::Handle handle;
- OP_REQUIRES_OK_ASYNC(ctx,
- lib->Instantiate(func_.name(), AttrSlice(&attr_values),
- {target_device}, &handle),
- done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx, lib->Instantiate(func_.name(), AttrSlice(&attr_values), &handle),
+ done);
OpInputList arguments;
OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 7ccaef9..9b4dca8 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -575,27 +575,6 @@
bool is_training_;
};
-namespace {
-
-template <typename Device>
-void FillZeros(Tensor* t);
-
-#if GOOGLE_CUDA
-template <>
-void FillZeros<GPUDevice>(Tensor* t) {
- cudaMemset(const_cast<char*>(t->tensor_data().data()), 0,
- t->tensor_data().size());
-}
-#endif
-
-template <>
-void FillZeros<CPUDevice>(Tensor* t) {
- memset(const_cast<char*>(t->tensor_data().data()), 0,
- t->tensor_data().size());
-}
-
-} // namespace
-
template <typename Device, typename T, typename U>
class FusedBatchNormGradOp : public OpKernel {
public:
@@ -659,11 +638,12 @@
Tensor* placeholder_1 = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(3, TensorShape({}), &placeholder_1));
- FillZeros<Device>(placeholder_1);
+ functor::SetZeroFunctor<Device, float> f;
+ f(context->eigen_device<Device>(), placeholder_1->flat<U>());
Tensor* placeholder_2 = nullptr;
OP_REQUIRES_OK(
context, context->allocate_output(4, TensorShape({}), &placeholder_2));
- FillZeros<Device>(placeholder_2);
+ f(context->eigen_device<Device>(), placeholder_2->flat<U>());
// If input is empty, set gradients w.r.t scale/offset to zero.
if (x.shape().num_elements() == 0) {
diff --git a/tensorflow/core/kernels/quantization_utils_test.cc b/tensorflow/core/kernels/quantization_utils_test.cc
index a73581f..d148c9f 100644
--- a/tensorflow/core/kernels/quantization_utils_test.cc
+++ b/tensorflow/core/kernels/quantization_utils_test.cc
@@ -743,7 +743,8 @@
void TestDivide64x2Pow(int64 val, int64 ref) {
const int64x2_t val_64x2 = vmovq_n_s64(val);
const int64x2_t ret = Divide64x2Pow<POW>(val_64x2);
- int64 rets[2];
+ // TODO(b/70947959) Change back to int64 when possible
+ int64_t rets[2];
vst1q_s64(rets, ret);
EXPECT_EQ(rets[0], ref);
EXPECT_EQ(rets[1], ref);
@@ -754,7 +755,8 @@
void TestDivide64x2PowRound(int64 val, int64 ref) {
const int64x2_t val_64x2 = vmovq_n_s64(val);
const int64x2_t shifted = Divide64x2PowRound<POW>(val_64x2);
- int64 rets[2];
+ // TODO(b/70947959) Change back to int64 when possible
+ int64_t rets[2];
vst1q_s64(rets, shifted);
EXPECT_EQ(rets[0], ref) << "in = " << val << ", " << POW
<< ", act = " << rets[0] << ", ref = " << ref;
diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc
index f092afe..a86c046 100644
--- a/tensorflow/core/kernels/summary_kernels.cc
+++ b/tensorflow/core/kernels/summary_kernels.cc
@@ -67,9 +67,8 @@
SummaryWriterInterface* s;
auto db = Sqlite::Open(db_uri);
OP_REQUIRES_OK(ctx, db.status());
- db.ValueOrDie()->UseWriteAheadLogWithReducedDurabilityIfPossible();
OP_REQUIRES_OK(
- ctx, CreateSummaryDbWriter(std::move(db.ValueOrDie()), experiment_name,
+ ctx, CreateSummaryDbWriter(db.ConsumeValueOrDie(), experiment_name,
run_name, user_name, ctx->env(), &s));
OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s));
}
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index cca6d0e..66aee2d 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -336,8 +336,7 @@
tensor_array->HasIdenticalElementShapes(), false /* dynamic_size */,
true /* multiple_writes_aggregate */, true /* is_grad */,
marked_size /* marked_size */, true /* close_after_read */);
- TF_RETURN_IF_ERROR((*ret)->CopyShapesFrom(tensor_array));
- return Status::OK();
+ return (*ret)->CopyShapesFrom(tensor_array);
};
Status s = rm->LookupOrCreate<TensorArray>(
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index 7824702..e64b27b 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -83,65 +83,100 @@
}
}
- auto Tin = input.shaped<T, 3>(new_sizes);
-
Tensor* idx = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
- 1, TensorShape({Tin.dimension(1)}), &idx));
+ 1, TensorShape({new_sizes[1]}), &idx));
auto idx_vec = idx->template vec<TIndex>();
- auto hash_fn = [&Tin](const int64& key) -> unsigned long {
- size_t h = 0;
- for (int64 i = 0; i < Tin.dimension(0); i++) {
- for (int64 j = 0; j < Tin.dimension(2); j++) {
- h = Hash64Combine(h, hash<T>{}(Tin(i, key, j)));
+ int64 uniq_size;
+ if (new_sizes[0] == 1 && new_sizes[2] == 1) {
+ // Specialized and faster implementation when unique is run over single
+ // elements. Here we put T directly into the map rather than ints pointing
+ // to them as in the general case.
+ auto Tin = input.flat<T>();
+ const int64 N = static_cast<int64>(Tin.size());
+
+ std::unordered_map<T, TIndex> uniq;
+ uniq.reserve(2 * N);
+ for (int64 i = 0, j = 0; i < N; ++i) {
+ auto it = uniq.insert(std::make_pair(Tin(i), j));
+ idx_vec(i) = it.first->second;
+ if (it.second) {
+ ++j;
}
}
- return h;
- };
- auto equal_to_fn = [&Tin](const int64& lhs, const int64& rhs) {
- for (int64 i = 0; i < Tin.dimension(0); i++) {
- for (int64 j = 0; j < Tin.dimension(2); j++) {
- if (Tin(i, lhs, j) != Tin(i, rhs, j)) {
- return false;
+ uniq_size = static_cast<int64>(uniq.size());
+ TensorShape output_shape(input.shape());
+ output_shape.set_dim(axis, uniq_size);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, output_shape, &output));
+ auto Tout = output->flat<T>();
+
+ for (auto it : uniq) {
+ Tout(it.second) = it.first;
+ }
+ } else {
+ // General implementation when unique is run over multiple elements.
+ auto Tin = input.shaped<T, 3>(new_sizes);
+
+ auto hash_fn = [&Tin](const int64& key) {
+ size_t h = 0;
+ for (int64 i = 0; i < Tin.dimension(0); i++) {
+ for (int64 j = 0; j < Tin.dimension(2); j++) {
+ h = Hash64Combine(h, hash<T>{}(Tin(i, key, j)));
}
}
+ return h;
+ };
+
+ auto equal_to_fn = [&Tin](const int64& lhs, const int64& rhs) {
+ for (int64 i = 0; i < Tin.dimension(0); i++) {
+ for (int64 j = 0; j < Tin.dimension(2); j++) {
+ if (Tin(i, lhs, j) != Tin(i, rhs, j)) {
+ return false;
+ }
+ }
+ }
+ return true;
+ };
+
+ std::unordered_map<int64, int64, decltype(hash_fn), decltype(equal_to_fn)>
+ uniq(0, hash_fn, equal_to_fn);
+
+ uniq.reserve(2 * Tin.dimension(1));
+
+ for (int64 i = 0, j = 0; i < Tin.dimension(1); ++i) {
+ auto it = uniq.insert(std::make_pair(i, j));
+ idx_vec(i) = it.first->second;
+ if (it.second) {
+ ++j;
+ }
}
- return true;
- };
- std::unordered_map<int64, int64, decltype(hash_fn), decltype(equal_to_fn)>
- uniq(0, hash_fn, equal_to_fn);
+ uniq_size = static_cast<int64>(uniq.size());
+ new_sizes[1] = uniq_size;
+ TensorShape output_shape(input.shape());
+ output_shape.set_dim(axis, uniq_size);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, output_shape, &output));
+ auto Tout = output->shaped<T, 3>(new_sizes);
- uniq.reserve(2 * Tin.dimension(1));
-
- for (int64 i = 0, j = 0; i < Tin.dimension(1); ++i) {
- auto it = uniq.insert(std::make_pair(i, j));
- idx_vec(i) = it.first->second;
- if (it.second) {
- ++j;
+ for (auto it : uniq) {
+ Tout.chip(it.second, 1) = Tin.chip(it.first, 1);
}
}
- int64 uniq_size = static_cast<int64>(uniq.size());
- new_sizes[1] = uniq_size;
- TensorShape output_shape(input.shape());
- output_shape.set_dim(axis, uniq_size);
- Tensor* output = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
- auto Tout = output->shaped<T, 3>(new_sizes);
-
- for (auto it : uniq) {
- Tout.chip(it.second, 1) = Tin.chip(it.first, 1);
- }
-
if (num_outputs() > 2) {
+ Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
2, TensorShape({uniq_size}), &output));
auto count_output_vec = output->template vec<TIndex>();
count_output_vec.setZero();
- for (int64 i = 0; i < Tin.dimension(1); ++i) {
+ // Tally how many input entries map to each unique index. N uses the same
+ // 64-bit type as the loop index so the element count is never narrowed
+ // to int.
+ const int64 N = idx_vec.size();
+ for (int64 i = 0; i < N; ++i) {
count_output_vec(idx_vec(i))++;
}
}
diff --git a/tensorflow/core/lib/db/sqlite.cc b/tensorflow/core/lib/db/sqlite.cc
index b0a9e2f..76bf778 100644
--- a/tensorflow/core/lib/db/sqlite.cc
+++ b/tensorflow/core/lib/db/sqlite.cc
@@ -14,224 +14,257 @@
==============================================================================*/
#include "tensorflow/core/lib/db/sqlite.h"
-#include "tensorflow/core/lib/io/record_reader.h"
-#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
extern "C" int sqlite3_snapfn_init(sqlite3*, const char**, const void*);
namespace tensorflow {
namespace {
-void ExecuteOrLog(Sqlite* db, const char* sql) {
- Status s = db->Prepare(sql).StepAndReset();
- if (!s.ok()) {
- LOG(WARNING) << s.ToString();
+// Maps a SQLite result code onto a TensorFlow error code.
+//
+// Only the low 8 bits of `code` are examined, so an extended result code is
+// classified by its primary code. Any code without an explicit case maps to
+// error::UNKNOWN.
+error::Code GetTfErrorCode(int code) {
+ // See: https://sqlite.org/rescode.html
+ switch (code & 0xff) {
+ case SQLITE_OK: // Successful result
+ case SQLITE_ROW: // Step has another row ready
+ case SQLITE_DONE: // Step has finished executing
+ return error::OK;
+ case SQLITE_ABORT: // Callback routine requested an abort
+ return error::ABORTED;
+ case SQLITE_READONLY: // Attempt to write a readonly database
+ case SQLITE_MISMATCH: // Data type mismatch
+ return error::FAILED_PRECONDITION;
+ case SQLITE_MISUSE: // Library used incorrectly
+ case SQLITE_INTERNAL: // Internal logic error in SQLite
+ return error::INTERNAL;
+ case SQLITE_RANGE: // 2nd parameter to sqlite3_bind out of range
+ return error::OUT_OF_RANGE;
+ case SQLITE_CANTOPEN: // Unable to open the database file
+ case SQLITE_CONSTRAINT: // Abort due to constraint violation
+ case SQLITE_NOTFOUND: // Unknown opcode or statement parameter name
+ case SQLITE_NOTADB: // File opened that is not a database file
+ return error::INVALID_ARGUMENT;
+ case SQLITE_CORRUPT: // The database disk image is malformed
+ return error::DATA_LOSS;
+ case SQLITE_AUTH: // Authorization denied
+ case SQLITE_PERM: // Access permission denied
+ return error::PERMISSION_DENIED;
+ case SQLITE_FULL: // Insertion failed because database is full
+ case SQLITE_TOOBIG: // String or BLOB exceeds size limit
+ case SQLITE_NOLFS: // Uses OS features not supported on host
+ return error::RESOURCE_EXHAUSTED;
+ case SQLITE_BUSY: // The database file is locked
+ case SQLITE_LOCKED: // A table in the database is locked
+ case SQLITE_PROTOCOL: // Database lock protocol error
+ case SQLITE_NOMEM: // Out of heap or perhaps lookaside memory
+ return error::UNAVAILABLE;
+ case SQLITE_INTERRUPT: // Operation terminated by sqlite3_interrupt
+ return error::CANCELLED;
+ case SQLITE_ERROR: // SQL error or missing database
+ case SQLITE_IOERR: // Some kind of disk I/O error occurred
+ case SQLITE_SCHEMA: // The database schema changed
+ default:
+ return error::UNKNOWN;
}
}
-string ExecuteOrEmpty(Sqlite* db, const char* sql) {
- auto stmt = db->Prepare(sql);
- bool is_done = false;
- if (stmt.Step(&is_done).ok() && !is_done) {
- return stmt.ColumnString(0);
+// Builds a Status whose code is derived from the SQLite result code `rc` and
+// whose message is `fmt` rendered by strings::Printf with `args`.
+template <typename... Args>
+Status PrintfStatus(int rc, const char* fmt, Args&&... args) {
+  const error::Code tf_code = GetTfErrorCode(rc);
+  return {tf_code, strings::Printf(fmt, std::forward<Args>(args)...)};
+}
+
+// Converts the SQLite result code `rc` into a Status. SQLITE_OK yields OK;
+// any other code is paired with the connection's current error message.
+Status AsStatus(Sqlite* db, int rc) EXCLUSIVE_LOCKS_REQUIRED(*db) {
+  if (TF_PREDICT_FALSE(rc != SQLITE_OK)) {
+    return {GetTfErrorCode(rc), db->errmsg()};
+  }
+  return Status::OK();
+}
+
+// Compiles the NUL-terminated `sql` on the raw connection `db` and returns
+// the statement handle. A failed prepare trips the CHECK, which logs `sql`.
+sqlite3_stmt* PrepareRawOrDie(sqlite3* db, const char* sql) {
+  sqlite3_stmt* prepared = nullptr;
+  const int result =
+      sqlite3_prepare_v2(db, sql, /*nByte=*/-1, &prepared, /*pzTail=*/nullptr);
+  CHECK_EQ(SQLITE_OK, result) << sql;
+  return prepared;
+}
+
+// Runs `PRAGMA <pragma>=<value>` on `db`, where <value> is read from the
+// environment variable `var`. Returns OK without touching the database when
+// the variable is unset or empty.
+Status SetEnvPragmaActual(Sqlite* db, const char* pragma, const char* var) {
+ const char* value = std::getenv(var);
+ if (value == nullptr || *value == '\0') return Status::OK();
+ // The value is spliced into the SQL text below, so only [0-9A-Za-z-] is
+ // accepted; anything else is rejected before a statement is built.
+ for (const char* p = value; *p != '\0'; ++p) {
+ if (!(('0' <= *p && *p <= '9') || *p == '-' ||
+ ('A' <= *p && *p <= 'Z') ||
+ ('a' <= *p && *p <= 'z'))) {
+ return errors::InvalidArgument("Illegal character");
+ }
}
- return "";
+ // We can't use Bind*() for pragmas.
+ auto stmt = db->Prepare(strings::StrCat("PRAGMA ", pragma, "=", value));
+ TF_RETURN_IF_ERROR(stmt.status());
+ // Only the status of the step is returned; the done flag is discarded.
+ bool unused_done;
+ return stmt.ValueOrDie().Step(&unused_done);
+}
+
+// Wrapper around SetEnvPragmaActual() that, on failure, adds "getenv(<var>)"
+// as context to the returned error.
+Status EnvPragma(Sqlite* db, const char* pragma, const char* var) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(SetEnvPragmaActual(db, pragma, var),
+ "getenv(", var, ")");
+ return Status::OK();
}
} // namespace
/* static */
-xla::StatusOr<std::shared_ptr<Sqlite>> Sqlite::Open(const string& uri) {
+xla::StatusOr<std::shared_ptr<Sqlite>> Sqlite::Open(string path, int flags) {
+ // Opens `path` with `flags` (a private page cache is always requested),
+ // prepares the BEGIN/COMMIT/ROLLBACK statements, then applies any PRAGMAs
+ // requested through TF_SQLITE_* environment variables. Returns the first
+ // error encountered while opening or applying a pragma.
+ flags |= SQLITE_OPEN_PRIVATECACHE;
sqlite3* sqlite = nullptr;
- TF_RETURN_IF_ERROR(MakeStatus(sqlite3_open(uri.c_str(), &sqlite)));
- CHECK_EQ(SQLITE_OK, sqlite3_snapfn_init(sqlite, nullptr, nullptr));
- Sqlite* db = new Sqlite(sqlite, uri);
- // This is the SQLite default since 2016. However it's good to set
- // this anyway, since we might get linked against an older version of
- // the library, and it's pretty much impossible to change later.
- ExecuteOrLog(db, "PRAGMA page_size=4096");
- return std::shared_ptr<Sqlite>(db);
-}
-
-/* static */ Status Sqlite::MakeStatus(int resultCode) {
- // See: https://sqlite.org/rescode.html
- switch (resultCode & 0xff) {
- case SQLITE_OK:
- case SQLITE_ROW: // sqlite3_step() has another row ready
- case SQLITE_DONE: // sqlite3_step() has finished executing
- return Status::OK();
- case SQLITE_ABORT: // Callback routine requested an abort
- return errors::Aborted(sqlite3_errstr(resultCode));
- case SQLITE_READONLY: // Attempt to write a readonly database
- case SQLITE_MISMATCH: // Data type mismatch
- return errors::FailedPrecondition(sqlite3_errstr(resultCode));
- case SQLITE_MISUSE: // Library used incorrectly
- case SQLITE_INTERNAL: // Internal logic error in SQLite
- return errors::Internal(sqlite3_errstr(resultCode));
- case SQLITE_RANGE: // 2nd parameter to sqlite3_bind out of range
- return errors::OutOfRange(sqlite3_errstr(resultCode));
- case SQLITE_CANTOPEN: // Unable to open the database file
- case SQLITE_CONSTRAINT: // Abort due to constraint violation
- case SQLITE_NOTFOUND: // Unknown opcode or statement parameter name
- case SQLITE_NOTADB: // File opened that is not a database file
- return errors::InvalidArgument(sqlite3_errstr(resultCode));
- case SQLITE_CORRUPT: // The database disk image is malformed
- return errors::DataLoss(sqlite3_errstr(resultCode));
- case SQLITE_AUTH: // Authorization denied
- case SQLITE_PERM: // Access permission denied
- return errors::PermissionDenied(sqlite3_errstr(resultCode));
- case SQLITE_FULL: // Insertion failed because database is full
- case SQLITE_TOOBIG: // String or BLOB exceeds size limit
- case SQLITE_NOLFS: // Uses OS features not supported on host
- return errors::ResourceExhausted(sqlite3_errstr(resultCode));
- case SQLITE_BUSY: // The database file is locked
- case SQLITE_LOCKED: // A table in the database is locked
- case SQLITE_PROTOCOL: // Database lock protocol error
- case SQLITE_NOMEM: // A malloc() failed
- return errors::Unavailable(sqlite3_errstr(resultCode));
- case SQLITE_INTERRUPT: // Operation terminated by sqlite3_interrupt
- return errors::Cancelled(sqlite3_errstr(resultCode));
- case SQLITE_ERROR: // SQL error or missing database
- case SQLITE_IOERR: // Some kind of disk I/O error occurred
- case SQLITE_SCHEMA: // The database schema changed
- default:
- return errors::Unknown(sqlite3_errstr(resultCode));
+ int rc = sqlite3_open_v2(path.c_str(), &sqlite, flags, nullptr);
+ if (rc != SQLITE_OK) {
+ // SQLite normally hands back a connection handle even when the open
+ // fails, and it must still be released. Closing a null handle is a
+ // harmless no-op, so no extra check is needed.
+ sqlite3_close(sqlite);
+ return PrintfStatus(rc, "Sqlite::Open(%s) failed: %s", path.c_str(),
+ sqlite3_errstr(rc));
}
+ CHECK_EQ(SQLITE_OK, sqlite3_extended_result_codes(sqlite, 1));
+ CHECK_EQ(SQLITE_OK, sqlite3_snapfn_init(sqlite, nullptr, nullptr));
+ // Prepare these tiny privileged statements for SqliteTransaction
+ // so it can do less work, particularly in its constructor, per
+ // Google C++ Style.
+ sqlite3_stmt* begin = PrepareRawOrDie(sqlite, "BEGIN");
+ sqlite3_stmt* commit = PrepareRawOrDie(sqlite, "COMMIT");
+ sqlite3_stmt* rollback = PrepareRawOrDie(sqlite, "ROLLBACK");
+ auto r = std::shared_ptr<Sqlite>(
+ new Sqlite(sqlite, std::move(path), begin, commit, rollback));
+ // Keep a weak self-reference so Prepare() can give each statement a
+ // shared_ptr back to this connection.
+ r->self_ = std::weak_ptr<Sqlite>(r);
+ Sqlite* db = r.get();
+ // TensorFlow is designed to work well in all SQLite modes. However
+ // users might find tuning some of these pragmas rewarding, depending on
+ // various considerations.
+ TF_RETURN_IF_ERROR(EnvPragma(db, "secure_delete", "TF_SQLITE_SECURE_DELETE"));
+ TF_RETURN_IF_ERROR(EnvPragma(db, "page_size", "TF_SQLITE_PAGE_SIZE"));
+ TF_RETURN_IF_ERROR(EnvPragma(db, "journal_mode", "TF_SQLITE_JOURNAL_MODE"));
+ TF_RETURN_IF_ERROR(EnvPragma(db, "synchronous", "TF_SQLITE_SYNCHRONOUS"));
+ TF_RETURN_IF_ERROR(EnvPragma(db, "mmap_size", "TF_SQLITE_MMAP_SIZE"));
+ TF_RETURN_IF_ERROR(EnvPragma(db, "locking_mode", "TF_SQLITE_LOCKING_MODE"));
+ TF_RETURN_IF_ERROR(EnvPragma(db, "cache_size", "TF_SQLITE_CACHE_SIZE"));
+ TF_RETURN_IF_ERROR(EnvPragma(db, "auto_vacuum", "TF_SQLITE_AUTO_VACUUM"));
+ return r;
}
-Sqlite::Sqlite(sqlite3* db, const string& uri) : db_(db), uri_(uri) {}
-
Sqlite::~Sqlite() {
- // close_v2 doesn't care if a stmt hasn't been GC'd yet
- int rc = sqlite3_close_v2(db_);
- if (rc != SQLITE_OK) {
- LOG(ERROR) << "destruct sqlite3: " << MakeStatus(rc);
- }
+ // Finalize the three transaction statements owned by this object before
+ // closing the connection. The CHECK aborts the process if sqlite3_close()
+ // returns anything other than SQLITE_OK.
+ sqlite3_finalize(rollback_);
+ sqlite3_finalize(commit_);
+ sqlite3_finalize(begin_);
+ CHECK_EQ(SQLITE_OK, sqlite3_close(db_));
}
-Status Sqlite::Close() {
- if (db_ == nullptr) {
- return Status::OK();
- }
- // If Close is explicitly called, ordering must be correct.
- Status s = MakeStatus(sqlite3_close(db_));
- if (s.ok()) {
- db_ = nullptr;
- }
- return s;
-}
-
-void Sqlite::UseWriteAheadLogWithReducedDurabilityIfPossible() {
- // TensorFlow summaries are intensively write-heavy, cf. most apps.
- // This pragma loves writes and means that TensorBoard can read the
- // database even as the training job inserts stuff. In other words,
- // this makes SQLite almost as powerful as MySQL or PostgreSQL.
- // https://www.sqlite.org/wal.html
- string journal = ExecuteOrEmpty(this, "PRAGMA journal_mode=wal");
- if (journal != "wal") {
- LOG(WARNING) << "Failed to set journal_mode=wal because SQLite wants "
- << uri_ << " to be in '" << journal << "' mode, which might "
- << "be bad since WAL is important for the performance of "
- << "write-intensive apps. This might only happen for memory "
- << "databases or old versions of SQLite, but is definitely "
- << "worth fixing if that's not the case";
- } else {
- // This setting means we might lose transactions due to power loss,
- // but the database can't become corrupted. In exchange, we get the
- // the performance of a NoSQL database. This is a trade-off most data
- // scientists would consider acceptable.
- // https://www.sqlite.org/pragma.html#pragma_synchronous
- ExecuteOrLog(this, "PRAGMA synchronous=NORMAL");
- }
-}
-
-SqliteStatement Sqlite::Prepare(const string& sql) {
+// Compiles `sql` into a statement while holding a SqliteLock on this
+// connection. On failure the returned status carries the connection's error
+// message followed by the SQL text.
+xla::StatusOr<SqliteStatement> Sqlite::Prepare(const StringPiece& sql) {
+ SqliteLock lock(*this);
sqlite3_stmt* stmt = nullptr;
- int rc = sqlite3_prepare_v2(db_, sql.c_str(), sql.size() + 1, &stmt, nullptr);
- if (rc == SQLITE_OK) {
- return {stmt, SQLITE_OK, std::unique_ptr<string>(nullptr)};
- } else {
- return {nullptr, rc, std::unique_ptr<string>(new string(sql))};
+ // The byte length is passed explicitly rather than relying on a NUL
+ // terminator in the StringPiece's buffer.
+ int rc = sqlite3_prepare_v2(db_, sql.data(), static_cast<int>(sql.size()),
+ &stmt, nullptr);
+ if (rc != SQLITE_OK) {
+ // The "%.*s" precision must be an int; a size_t passed through varargs
+ // is read with the wrong width, so narrow it explicitly.
+ return PrintfStatus(rc, "Prepare() failed: %s: %.*s", errmsg(),
+ static_cast<int>(sql.size()), sql.data());
+ }
+ return SqliteStatement(stmt, self_.lock());
+}
+
+// Advances the statement with one sqlite3_step() call.
+//
+// *is_done is set to false only when a row is ready; it is set to true when
+// the statement has finished and on every error path.
+Status SqliteStatement::Step(bool* is_done) {
+ DCHECK(stmt_ != nullptr);
+ // A bind error recorded earlier is reported here, before SQLite is called.
+ if (TF_PREDICT_FALSE(bind_error_ != SQLITE_OK)) {
+ *is_done = true;
+ return PrintfStatus(bind_error_, "Bind(%d) failed: %s: %s",
+ bind_error_parameter_, sqlite3_errstr(bind_error_),
+ sql());
+ }
+ SqliteLock lock(*db_);
+ int rc = sqlite3_step(stmt_);
+ switch (rc) {
+ case SQLITE_ROW:
+ *is_done = false;
+ return Status::OK();
+ case SQLITE_DONE:
+ *is_done = true;
+ return Status::OK();
+ default:
+ *is_done = true;
+ return PrintfStatus(rc, "Step() failed: %s: %s", db_->errmsg(), sql());
}
}
-Status SqliteStatement::status() const {
- Status s = Sqlite::MakeStatus(error_);
- if (!s.ok()) {
- if (stmt_ != nullptr) {
- errors::AppendToMessage(&s, sqlite3_sql(stmt_));
- } else {
- errors::AppendToMessage(&s, *prepare_error_sql_);
- }
+// Steps the statement and aborts via TF_CHECK_OK if Step() fails.
+// Returns true when a row is ready, false once the statement has finished.
+bool SqliteStatement::StepOrDie() {
+  bool finished;
+  TF_CHECK_OK(Step(&finished));
+  const bool has_row = !finished;
+  return has_row;
+}
+
+// Steps the statement and requires that a row was produced. Propagates any
+// Step() error; returns an Internal error if the statement finished without
+// yielding a row.
+Status SqliteStatement::StepOnce() {
+ bool is_done;
+ TF_RETURN_IF_ERROR(Step(&is_done));
+ if (TF_PREDICT_FALSE(is_done)) {
+ return errors::Internal("No rows returned: ", sql());
}
+ return Status::OK();
+}
+
+// Like StepOnce(), but aborts via TF_CHECK_OK on failure. Returns *this.
+const SqliteStatement& SqliteStatement::StepOnceOrDie() {
+ TF_CHECK_OK(StepOnce());
+ return *this;
+}
+
+// Steps a statement that is expected to produce no rows, then resets it.
+// A successful step that yields a row is turned into an Internal error.
+// Reset() runs whether or not the step succeeded.
+Status SqliteStatement::StepAndReset() {
+ bool is_done;
+ Status s = Step(&is_done);
+ if (TF_PREDICT_FALSE(s.ok() && !is_done)) {
+ s = errors::Internal("Unexpected row: ", sql());
+ }
+ Reset();
return s;
}
-void SqliteStatement::CloseOrLog() {
- if (stmt_ != nullptr) {
- int rc = sqlite3_finalize(stmt_);
- if (rc != SQLITE_OK) {
- LOG(ERROR) << "destruct sqlite3_stmt: " << Sqlite::MakeStatus(rc);
- }
- stmt_ = nullptr;
- }
-}
-
-Status SqliteStatement::Close() {
- if (stmt_ == nullptr) {
- return Status::OK();
- }
- int rc = sqlite3_finalize(stmt_);
- if (rc == SQLITE_OK) {
- stmt_ = nullptr;
- }
- Update(rc);
- return status();
-}
+// Like StepAndReset(), but aborts via TF_CHECK_OK on failure.
+void SqliteStatement::StepAndResetOrDie() { TF_CHECK_OK(StepAndReset()); }
void SqliteStatement::Reset() {
if (TF_PREDICT_TRUE(stmt_ != nullptr)) {
sqlite3_reset(stmt_);
- sqlite3_clear_bindings(stmt_); // not nullptr friendly
+ sqlite3_clear_bindings(stmt_);
}
- error_ = SQLITE_OK;
+ // Clear the recorded bind error and zero size_ even when there is no
+ // underlying statement to reset.
+ bind_error_ = SQLITE_OK;
+ size_ = 0;
}
-Status SqliteStatement::Step(bool* isDone) {
- if (TF_PREDICT_FALSE(error_ != SQLITE_OK)) {
- *isDone = true;
- return status();
- }
- int rc = sqlite3_step(stmt_);
- switch (rc) {
- case SQLITE_ROW:
- *isDone = false;
- return Status::OK();
- case SQLITE_DONE:
- *isDone = true;
- return Status::OK();
- default:
- *isDone = true;
- error_ = rc;
- return status();
+// Enters the connection's SQLite mutex (left again in the destructor),
+// CHECKs that no transaction is already marked active on `db`, marks one
+// active, and issues BEGIN.
+SqliteTransaction::SqliteTransaction(Sqlite& db) : db_(&db) {
+ sqlite3_mutex_enter(sqlite3_db_mutex(db_->db_));
+ CHECK(!db_->is_in_transaction_);
+ db_->is_in_transaction_ = true;
+ Begin();
+}
+
+// Rolls back, resets the rollback and begin statements, clears the
+// in-transaction flag, and finally leaves the mutex entered by the
+// constructor.
+SqliteTransaction::~SqliteTransaction() {
+ // Rollback should only return an error if there's no transaction.
+ // Since the API performs auto-rollbacks in some cases, we ignore.
+ sqlite3_step(db_->rollback_);
+ sqlite3_reset(db_->rollback_);
+ sqlite3_reset(db_->begin_);
+ db_->is_in_transaction_ = false;
+ sqlite3_mutex_leave(sqlite3_db_mutex(db_->db_));
+}
+
+// Steps the pre-compiled BEGIN statement. Any result other than SQLITE_DONE
+// is fatal. The statement is not reset here; the destructor and Commit()
+// reset it.
+void SqliteTransaction::Begin() {
+ // This shouldn't allocate memory or perform I/O. All it does is
+ // execute OP_AutoCommit(0, 0) a.k.a. BEGIN DEFERRED which flips
+ // the sqlite3::autoCommit bit.
+ if (sqlite3_step(db_->begin_) != SQLITE_DONE) {
+ // It shouldn't be possible for this to fail since we already
+ // performed the reentrancy check.
+ LOG(FATAL) << "BEGIN failed: " << sqlite3_errmsg(db_->db_);
}
}
-Status SqliteStatement::StepAndReset() {
- if (TF_PREDICT_FALSE(error_ != SQLITE_OK)) {
- return status();
- }
- Status s;
- int rc = sqlite3_step(stmt_);
+// Steps the pre-compiled COMMIT statement. On success, resets the commit
+// and begin statements and immediately starts a new transaction with
+// Begin(). On failure, returns the error without resetting commit_.
+Status SqliteTransaction::Commit() {
+ int rc = sqlite3_step(db_->commit_);
if (rc != SQLITE_DONE) {
- if (rc == SQLITE_ROW) {
- s.Update(errors::Internal("unexpected sqlite row"));
- } else {
- s.Update(Sqlite::MakeStatus(rc));
- }
+ return PrintfStatus(rc, "COMMIT failed: %s", sqlite3_errmsg(db_->db_));
}
- Reset();
- return s;
+ sqlite3_reset(db_->commit_);
+ sqlite3_reset(db_->begin_);
+ Begin();
+ return Status::OK();
}
} // namespace tensorflow
diff --git a/tensorflow/core/lib/db/sqlite.h b/tensorflow/core/lib/db/sqlite.h
index 12840bd..49a989a 100644
--- a/tensorflow/core/lib/db/sqlite.h
+++ b/tensorflow/core/lib/db/sqlite.h
@@ -17,155 +17,212 @@
#include <cstddef>
#include <memory>
+#include <mutex>
#include <utility>
#include "sqlite3.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
+/// TensorFlow SQLite Veneer
+///
+/// - Memory safety
+/// - Less boilerplate
+/// - Removes deprecated stuff
+/// - Pretends UTF16 doesn't exist
+/// - Transaction compile-time safety
+/// - Statically loads our native extensions
+/// - Error reporting via tensorflow::Status et al.
+///
+/// SQLite>=3.8.2 needs to be supported until April 2019, which is when
+/// Ubuntu 14.04 LTS becomes EOL.
+
namespace tensorflow {
+class SqliteLock;
class SqliteStatement;
+class SqliteTransaction;
/// \brief SQLite connection object.
///
-/// This class is a thin wrapper around `sqlite3` that makes it easier
-/// and safer to use SQLite in the TensorFlow C++ codebase. It removes
-/// deprecated APIs, improves the safety of others, adds helpers, and
-/// pretends UTF16 doesn't exist.
+/// The SQLite connection is closed automatically by the destructor.
+/// Reference counting ensures that happens after its statements are
+/// destructed.
///
-/// Instances are thread safe, with the exception of Close().
-class Sqlite {
+/// This class offers the same thread safety behaviors and guarantees
+/// as the SQLite API itself.
+///
+/// This veneer uses auto-commit mode by default, which means a 4ms
+/// fsync() happens after every write unless a SqliteTransaction is
+/// used or WAL mode is enabled beforehand.
+class LOCKABLE Sqlite {
public:
- /// \brief Opens SQLite database file.
- ///
- /// The `uri` parameter can be a filename, or a proper URI like
- /// `file:/tmp/tf.sqlite?mode=ro&cache=private`. It can also be
- /// `file::memory:` for testing.
- ///
- /// See https://sqlite.org/c3ref/open.html
- static xla::StatusOr<std::shared_ptr<Sqlite>> Open(const string& uri);
-
- /// \brief Makes tensorflow::Status for SQLite result code.
- ///
- /// See https://sqlite.org/rescode.html
- static Status MakeStatus(int resultCode);
-
- /// \brief Destroys object and frees resources.
- ///
- /// This will free the underlying object if Close was not called. If
- /// an error code is returned then it will be logged.
- ///
- /// Note: Unlike Close() this destructor maps to sqlite3_close_v2(),
- /// which is lax about ordering and GC friendly.
+ /// \brief Closes SQLite connection, which can take milliseconds.
~Sqlite();
- /// \brief Frees underlying SQLite object.
+ /// \brief Opens SQLite database file.
///
- /// Unlike the destructor, all SqliteStatement objects must be closed
- /// beforehand. This is a no-op if already closed
- Status Close();
-
- /// \brief Enables WAL mode with less fsync or log a warning.
+ /// Notes on a few of the flags:
///
- /// The synchronous pragma is only set to NORMAL if WAL mode was
- /// successfully enabled. This must be called immediately after
- /// creating the object.
- void UseWriteAheadLogWithReducedDurabilityIfPossible();
+ /// - SQLITE_OPEN_READONLY: Allowed if no WAL journal is active.
+ /// - SQLITE_OPEN_SHAREDCACHE: Will be ignored because this veneer
+ /// doesn't support the unlock notify API.
+ /// - SQLITE_OPEN_NOMUTEX: Means access to this connection MUST be
+ /// serialized by the caller in accordance with the same contracts
+ /// implemented by this API.
+ ///
+ /// This function sets PRAGMA values from TF_SQLITE_* environment
+ /// variables. See sqlite.cc to learn more.
+ static xla::StatusOr<std::shared_ptr<Sqlite>> Open(string path, int flags);
+ static xla::StatusOr<std::shared_ptr<Sqlite>> Open(string path) {
+ return Open(std::move(path), SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE);
+ }
+ static std::shared_ptr<Sqlite> OpenOrDie(string path, int flags) {
+ return Open(std::move(path), flags).ValueOrDie();
+ }
+ static std::shared_ptr<Sqlite> OpenOrDie(string path) {
+ return Open(std::move(path)).ValueOrDie();
+ }
/// \brief Creates SQLite statement.
///
- /// Call result.status() to determine whether or not this operation
- /// failed. It is also possible to punt the error checking to after
- /// the values have been binded and Step() or ExecuteWriteQuery() is
- /// called.
- SqliteStatement Prepare(const string& sql);
+ /// If sql references tables then system calls taking microseconds
+ /// are needed and failure can happen on schema change. Otherwise
+ /// this should only fail on syntax error.
+ xla::StatusOr<SqliteStatement> Prepare(const StringPiece& sql);
+ SqliteStatement PrepareOrDie(const StringPiece& sql);
+
+ /// \brief Returns extended result code of last error.
+ ///
+ /// If the most recent API call was successful, the result is
+ /// undefined. The legacy result code can be obtained by saying
+ /// errcode() & 0xff.
+ int errcode() const EXCLUSIVE_LOCKS_REQUIRED(this) {
+ return sqlite3_extended_errcode(db_);
+ }
+
+ /// \brief Returns pointer to current error message state.
+ const char* errmsg() const EXCLUSIVE_LOCKS_REQUIRED(this) {
+ return sqlite3_errmsg(db_);
+ }
+
+ /// \brief Returns rowid assigned to last successful insert.
+ int64 last_insert_row_id() const EXCLUSIVE_LOCKS_REQUIRED(this) {
+ return sqlite3_last_insert_rowid(db_);
+ }
private:
- explicit Sqlite(sqlite3* db, const string& uri);
- sqlite3* db_;
- string uri_;
+ friend class SqliteLock;
+ friend class SqliteStatement;
+ friend class SqliteTransaction;
+
+ Sqlite(sqlite3* db, const string path, sqlite3_stmt* begin,
+ sqlite3_stmt* commit, sqlite3_stmt* rollback) noexcept
+ : db_(db),
+ path_(std::move(path)),
+ begin_(begin),
+ commit_(commit),
+ rollback_(rollback) {}
+
+ sqlite3* const db_;
+ const string path_;
+ sqlite3_stmt* const begin_;
+ sqlite3_stmt* const commit_;
+ sqlite3_stmt* const rollback_;
+ bool is_in_transaction_ = false;
+ std::weak_ptr<Sqlite> self_; // so prepare can pass to statements
+
TF_DISALLOW_COPY_AND_ASSIGN(Sqlite);
};
-/// \brief SQLite prepared statement cursor object.
+/// \brief SQLite prepared statement.
///
-/// This class tracks error state internally, like Status::Update.
+/// Instances can only be shared between threads if caller serializes
+/// access from first Bind*() to *Reset().
///
-/// Instances of this class are not thread safe.
+/// When reusing a statement in a loop, be certain to not have jumps
+/// betwixt Bind*() and *Reset().
class SqliteStatement {
public:
- /// \brief Constructs empty statement that should be assigned later.
- SqliteStatement() : stmt_(nullptr), error_(SQLITE_OK) {}
+ /// \brief Initializes an empty statement to be assigned later.
+ SqliteStatement() noexcept = default;
- /// \brief Empties object and finalizes statement if needed.
- ~SqliteStatement() { CloseOrLog(); }
+ /// \brief Finalizes statement.
+ ///
+ /// This can take milliseconds if it was blocking the Sqlite
+ /// connection object from being freed.
+ ~SqliteStatement() { /* ignore */ sqlite3_finalize(stmt_); }
- /// \brief Move constructor, after which <other> should not be used.
- SqliteStatement(SqliteStatement&& other);
-
- /// \brief Move assignment, after which <other> should not be used.
- SqliteStatement& operator=(SqliteStatement&& other);
-
- /// \brief Returns true if statement is not empty.
+ /// \brief Returns true if statement is initialized.
explicit operator bool() const { return stmt_ != nullptr; }
- /// \brief Returns SQLite result code state.
- ///
- /// This will be SQLITE_OK unless an error happened. If multiple
- /// errors happened, only the first error code will be returned.
- int error() const { return error_; }
+ /// \brief Returns SQL text from when this query was prepared.
+ const char* sql() const { return sqlite3_sql(stmt_); }
- /// \brief Returns error() as a tensorflow::Status.
- Status status() const;
+ /// \brief Number of bytes bound since last *Reset().
+ uint64 size() { return size_; }
- /// \brief Finalize statement object.
+ /// \brief Executes query for fetching arbitrary rows.
///
- /// Please note that the destructor can also do this. This method is
- /// a no-op if already closed.
- Status Close();
+ /// `is_done` will always be set to true unless SQLITE_ROW is
+ /// returned by the underlying API. If status() is already in an
+ /// error state, then this method is a no-op and the existing status
+ /// is returned.
+ ///
+ /// The OrDie version returns `!is_done` which, if true, indicates a
+ /// row is available.
+ ///
+ /// This statement should be Reset() or destructed when finished
+ /// with the result.
+ Status Step(bool* is_done);
+ bool StepOrDie() TF_MUST_USE_RESULT;
- /// \brief Executes query and/or fetches next row.
+ /// \brief Executes query when only one row is desired.
///
- /// `isDone` will always be set to true unless SQLITE_ROW is returned
- /// by the underlying API. If status() is already in an error state,
- /// then this method is a no-op and the existing status is returned.
- Status Step(bool* isDone);
+ /// If a row isn't returned, an internal error Status is returned
+ /// that won't be reflected in the connection error state.
+ ///
+ /// This statement should be Reset() or destructed when finished
+ /// with the result.
+ Status StepOnce();
+ const SqliteStatement& StepOnceOrDie();
- /// \brief Executes query that returns no data.
+ /// \brief Executes query, ensures zero rows returned, then Reset().
///
- /// This helper calls Step(), ensures SQLITE_DONE was returned, then
- /// resets the statement and clears the bindings. If status() is
- /// already in an error state, then this method is a no-op and the
- /// existing status is returned.
+ /// If a row is returned, an internal error Status is returned that
+ /// won't be reflected in the connection error state.
Status StepAndReset();
+ void StepAndResetOrDie();
/// \brief Resets statement so it can be executed again.
///
- /// - Resets the prepared statement
- /// - Sets all Bind*() values to NULL
- ///
- /// Support for calling sqlite3_reset() and sqlite3_clear_bindings()
- /// independently may be added in the future if a compelling use case
- /// can be demonstrated.
+ /// Implementation note: This method diverges from canonical API
+ /// behavior by calling sqlite3_clear_bindings() in addition to
+ /// sqlite3_reset(). That makes the veneer safer; we haven't found a
+ /// super compelling reason yet to call them independently.
void Reset();
/// \brief Binds signed 64-bit integer to 1-indexed query parameter.
void BindInt(int parameter, int64 value) {
- Update(sqlite3_bind_int64(stmt_, parameter, value));
+ Update(sqlite3_bind_int64(stmt_, parameter, value), parameter);
+ size_ += sizeof(int64);
}
- void BindInt(const string& parameter, int64 value) {
+ void BindInt(const char* parameter, int64 value) {
BindInt(GetParameterIndex(parameter), value);
}
/// \brief Binds double to 1-indexed query parameter.
void BindDouble(int parameter, double value) {
- Update(sqlite3_bind_double(stmt_, parameter, value));
+ Update(sqlite3_bind_double(stmt_, parameter, value), parameter);
+ size_ += sizeof(double);
}
- void BindDouble(const string& parameter, double value) {
+ void BindDouble(const char* parameter, double value) {
BindDouble(GetParameterIndex(parameter), value);
}
@@ -174,69 +231,67 @@
/// If NUL characters are present, they will still go in the DB and
/// be successfully retrieved by ColumnString(); however, the
/// behavior of these values with SQLite functions is undefined.
- void BindText(int parameter, const string& text) {
+ ///
+ /// When using the unsafe methods, the data must not be changed or
+ /// freed until this statement is Reset() or finalized.
+ void BindText(int parameter, const StringPiece& text) {
Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(),
- SQLITE_TRANSIENT, SQLITE_UTF8));
+ SQLITE_TRANSIENT, SQLITE_UTF8), parameter);
+ size_ += text.size();
}
- void BindText(const string& parameter, const string& text) {
+ void BindText(const char* parameter, const StringPiece& text) {
BindText(GetParameterIndex(parameter), text);
}
-
- /// \brief Copies binary data to 1-indexed query parameter.
- void BindBlob(int parameter, const string& blob) {
- Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(),
- SQLITE_TRANSIENT));
- }
- void BindBlob(const string& parameter, const string& blob) {
- BindBlob(GetParameterIndex(parameter), blob);
- }
-
- /// \brief Binds UTF-8 text to 1-indexed query parameter.
- ///
- /// The contents of `text` must not be changed or freed until Reset()
- /// or Close() is called.
- ///
- /// If NUL characters are present, they will still go in the DB and
- /// be successfully retrieved by ColumnString(); however, the
- /// behavior of these values with SQLite functions is undefined.
- void BindTextUnsafe(int parameter, const string& text) {
+ void BindTextUnsafe(int parameter, const StringPiece& text) {
Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(),
- SQLITE_STATIC, SQLITE_UTF8));
+ SQLITE_STATIC, SQLITE_UTF8), parameter);
+ size_ += text.size();
}
- void BindTextUnsafe(const string& parameter, const string& text) {
+ void BindTextUnsafe(const char* parameter, const StringPiece& text) {
BindTextUnsafe(GetParameterIndex(parameter), text);
}
- /// \brief Binds binary data to 1-indexed query parameter.
+ /// \brief Copies binary data to 1-indexed query parameter.
///
- /// The contents of `blob` must not be changed or freed until Reset()
- /// or Close() is called.
- void BindBlobUnsafe(int parameter, const string& blob) {
+ /// When using the unsafe methods, the data must not be changed or
+ /// freed until this statement is Reset() or finalized.
+ void BindBlob(int parameter, const StringPiece& blob) {
Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(),
- SQLITE_STATIC));
+ SQLITE_TRANSIENT), parameter);
+ size_ += blob.size();
}
- void BindBlobUnsafe(const string& parameter, const string& text) {
+ void BindBlob(const char* parameter, const StringPiece& blob) {
+ BindBlob(GetParameterIndex(parameter), blob);
+ }
+ void BindBlobUnsafe(int parameter, const StringPiece& blob) {
+ Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(),
+ SQLITE_STATIC), parameter);
+ size_ += blob.size();
+ }
+ void BindBlobUnsafe(const char* parameter, const StringPiece& text) {
BindBlobUnsafe(GetParameterIndex(parameter), text);
}
/// \brief Returns number of columns in result set.
- int ColumnCount() TF_MUST_USE_RESULT { return sqlite3_column_count(stmt_); }
+ int ColumnCount() const TF_MUST_USE_RESULT {
+ return sqlite3_column_count(stmt_);
+ }
/// \brief Returns type of 0-indexed column value in row data.
///
/// Please note that SQLite is dynamically typed and the type of a
/// particular column can vary from row to row.
- int ColumnType(int column) TF_MUST_USE_RESULT {
+ int ColumnType(int column) const TF_MUST_USE_RESULT {
return sqlite3_column_type(stmt_, column);
}
/// \brief Returns 0-indexed column from row result coerced as an integer.
- int64 ColumnInt(int column) TF_MUST_USE_RESULT {
+ int64 ColumnInt(int column) const TF_MUST_USE_RESULT {
return sqlite3_column_int64(stmt_, column);
}
/// \brief Returns 0-indexed column from row result coerced as a double.
- double ColumnDouble(int column) TF_MUST_USE_RESULT {
+ double ColumnDouble(int column) const TF_MUST_USE_RESULT {
return sqlite3_column_double(stmt_, column);
}
@@ -244,80 +299,141 @@
///
/// NULL values are returned as empty string. This method should be
/// used for both BLOB and TEXT columns. See also: ColumnType().
- string ColumnString(int column) TF_MUST_USE_RESULT {
+ string ColumnString(int column) const TF_MUST_USE_RESULT {
auto data = sqlite3_column_blob(stmt_, column);
- if (data == nullptr) {
- return "";
- }
+ if (data == nullptr) return "";
return {static_cast<const char*>(data),
static_cast<size_t>(ColumnSize(column))};
}
/// \brief Returns pointer to binary data at 0-indexed column.
///
- /// The returned memory will be mutated or freed the next time
- /// Step() or Reset() is called. No NUL terminator is added. See
- /// ColumnSize(). Please note that an empty BLOB is NULL.
- const char* ColumnStringUnsafe(int column) TF_MUST_USE_RESULT {
- return static_cast<const char*>(sqlite3_column_blob(stmt_, column));
+ /// Empty values are returned as NULL. The returned memory will no
+ /// longer be valid the next time Step() or Reset() is called. No NUL
+ /// terminator is added.
+ StringPiece ColumnStringUnsafe(int column) const TF_MUST_USE_RESULT {
+ return {static_cast<const char*>(sqlite3_column_blob(stmt_, column)),
+ static_cast<size_t>(ColumnSize(column))};
}
/// \brief Returns number of bytes stored at 0-indexed column.
- int ColumnSize(int column) TF_MUST_USE_RESULT {
+ int ColumnSize(int column) const TF_MUST_USE_RESULT {
return sqlite3_column_bytes(stmt_, column);
}
- private:
- friend Sqlite;
- SqliteStatement(sqlite3_stmt* stmt, int error,
- std::unique_ptr<string> prepare_error_sql)
- : stmt_(stmt),
- error_(error),
- prepare_error_sql_(std::move(prepare_error_sql)) {}
- void CloseOrLog();
+ /// \brief Move constructor, after which <other> is reset to empty.
+ SqliteStatement(SqliteStatement&& other) noexcept
+ : stmt_(other.stmt_),
+ db_(std::move(other.db_)),
+ bind_error_(other.bind_error_) {
+ other.stmt_ = nullptr;
+ other.bind_error_ = SQLITE_OK;
+ }
- void Update(int rc) {
+ /// \brief Move assignment, after which <other> is reset to empty.
+ SqliteStatement& operator=(SqliteStatement&& other) noexcept {
+ if (&other != this) {
+ sqlite3_finalize(stmt_);
+ stmt_ = other.stmt_;
+ bind_error_ = other.bind_error_;
+ db_ = std::move(other.db_);
+ size_ = 0;
+ other.stmt_ = nullptr;
+ other.bind_error_ = SQLITE_OK;
+ }
+ return *this;
+ }
+
+ private:
+ friend class Sqlite;
+
+ SqliteStatement(sqlite3_stmt* stmt, std::shared_ptr<Sqlite> db) noexcept
+ : stmt_(stmt), db_(std::move(db)) {}
+
+ void Update(int rc, int parameter) {
+ // Binding strings can fail if they exceed length limit.
if (TF_PREDICT_FALSE(rc != SQLITE_OK)) {
- if (error_ == SQLITE_OK) {
- error_ = rc;
+ if (bind_error_ == SQLITE_OK) {
+ bind_error_ = rc;
+ bind_error_parameter_ = parameter;
}
}
}
- int GetParameterIndex(const string& parameter) {
- // Each call to this function requires O(n) strncmp().
- int index = sqlite3_bind_parameter_index(stmt_, parameter.c_str());
- if (TF_PREDICT_FALSE(index == 0)) {
- Update(SQLITE_NOTFOUND);
- }
+ int GetParameterIndex(const char* parameter) {
+ int index = sqlite3_bind_parameter_index(stmt_, parameter);
+ DCHECK(index > 0); // OK to compile away since it'll fail again
return index;
}
- sqlite3_stmt* stmt_;
- int error_;
- std::unique_ptr<string> prepare_error_sql_;
+ sqlite3_stmt* stmt_ = nullptr;
+ std::shared_ptr<Sqlite> db_;
+ int bind_error_ = SQLITE_OK;
+ int bind_error_parameter_ = 0;
+ uint64 size_ = 0;
TF_DISALLOW_COPY_AND_ASSIGN(SqliteStatement);
};
-inline SqliteStatement::SqliteStatement(SqliteStatement&& other)
- : stmt_(other.stmt_),
- error_(other.error_),
- prepare_error_sql_(std::move(other.prepare_error_sql_)) {
- other.stmt_ = nullptr;
- other.error_ = SQLITE_OK;
-}
-
-inline SqliteStatement& SqliteStatement::operator=(SqliteStatement&& other) {
- if (&other != this) {
- CloseOrLog();
- stmt_ = other.stmt_;
- error_ = other.error_;
- prepare_error_sql_ = std::move(other.prepare_error_sql_);
- other.stmt_ = nullptr;
- other.error_ = SQLITE_OK;
+/// \brief Reentrant SQLite connection object lock
+///
+/// This is a no-op if SQLITE_OPEN_NOMUTEX was used.
+class SCOPED_LOCKABLE SqliteLock {
+ public:
+ explicit SqliteLock(Sqlite& db) EXCLUSIVE_LOCK_FUNCTION(db)
+ : mutex_(sqlite3_db_mutex(db.db_)) {
+ sqlite3_mutex_enter(mutex_);
}
- return *this;
+ SqliteLock(Sqlite& db, std::try_to_lock_t) EXCLUSIVE_LOCK_FUNCTION(db)
+ : mutex_(sqlite3_db_mutex(db.db_)) {
+ if (TF_PREDICT_FALSE(sqlite3_mutex_try(mutex_) != SQLITE_OK)) {
+ is_locked_ = false;
+ }
+ }
+ ~SqliteLock() UNLOCK_FUNCTION() {
+ if (is_locked_) sqlite3_mutex_leave(mutex_);
+ }
+ explicit operator bool() const { return is_locked_; }
+
+ private:
+ sqlite3_mutex* const mutex_;
+ bool is_locked_ = true;
+ TF_DISALLOW_COPY_AND_ASSIGN(SqliteLock);
+};
+#define SqliteLock(x) static_assert(0, "sqlite_lock_decl_missing_name");
+
+/// \brief SQLite transaction scope.
+///
+/// This class acquires an exclusive lock on the connection object (if
+/// mutexes weren't disabled) and runs BEGIN / ROLLBACK automatically.
+/// Unlike SqliteLock this scope is non-reentrant. To avoid program
+/// crashes, business logic should use the EXCLUSIVE_LOCK_FUNCTION and
+/// LOCKS_EXCLUDED annotations as much as possible.
+class SCOPED_LOCKABLE SqliteTransaction {
+ public:
+ /// \brief Locks db and begins deferred transaction.
+ ///
+ /// This will crash if a transaction is already active.
+ explicit SqliteTransaction(Sqlite& db) EXCLUSIVE_LOCK_FUNCTION(db);
+
+ /// \brief Runs ROLLBACK and unlocks.
+ ~SqliteTransaction() UNLOCK_FUNCTION();
+
+ /// \brief Commits transaction.
+ ///
+ /// If this is successful, a new transaction will be started, which
+ /// is rolled back when exiting the scope.
+ Status Commit();
+
+ private:
+ void Begin();
+ Sqlite* const db_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SqliteTransaction);
+};
+
+inline SqliteStatement Sqlite::PrepareOrDie(const StringPiece& sql) {
+ return Prepare(sql).ValueOrDie();
}
} // namespace tensorflow
diff --git a/tensorflow/core/lib/db/sqlite_test.cc b/tensorflow/core/lib/db/sqlite_test.cc
index 29772b8..f93b3d8 100644
--- a/tensorflow/core/lib/db/sqlite_test.cc
+++ b/tensorflow/core/lib/db/sqlite_test.cc
@@ -14,13 +14,13 @@
==============================================================================*/
#include "tensorflow/core/lib/db/sqlite.h"
-#include <limits.h>
#include <array>
+#include <climits>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -29,23 +29,22 @@
class SqliteTest : public ::testing::Test {
protected:
void SetUp() override {
- db_ = Sqlite::Open(":memory:").ValueOrDie();
- auto stmt = db_->Prepare("CREATE TABLE T (a BLOB, b BLOB)");
- TF_ASSERT_OK(stmt.StepAndReset());
+ db_ = Sqlite::OpenOrDie(":memory:");
+ db_->PrepareOrDie("CREATE TABLE T (a BLOB, b BLOB)").StepAndResetOrDie();
}
std::shared_ptr<Sqlite> db_;
bool is_done_;
};
TEST_F(SqliteTest, InsertAndSelectInt) {
- auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
stmt.BindInt(1, 3);
stmt.BindInt(2, -7);
TF_ASSERT_OK(stmt.StepAndReset());
stmt.BindInt(1, 123);
stmt.BindInt(2, -123);
TF_ASSERT_OK(stmt.StepAndReset());
- stmt = db_->Prepare("SELECT a, b FROM T ORDER BY b");
+ stmt = db_->PrepareOrDie("SELECT a, b FROM T ORDER BY b");
TF_ASSERT_OK(stmt.Step(&is_done_));
ASSERT_FALSE(is_done_);
EXPECT_EQ(123, stmt.ColumnInt(0));
@@ -59,11 +58,11 @@
}
TEST_F(SqliteTest, InsertAndSelectDouble) {
- auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
stmt.BindDouble(1, 6.28318530);
stmt.BindDouble(2, 1.61803399);
TF_ASSERT_OK(stmt.StepAndReset());
- stmt = db_->Prepare("SELECT a, b FROM T");
+ stmt = db_->PrepareOrDie("SELECT a, b FROM T");
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ(6.28318530, stmt.ColumnDouble(0));
EXPECT_EQ(1.61803399, stmt.ColumnDouble(1));
@@ -74,11 +73,11 @@
TEST_F(SqliteTest, NulCharsInString) {
string s; // XXX: Want to write {2, '\0'} but not sure why not.
s.append(static_cast<size_t>(2), '\0');
- auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
stmt.BindBlob(1, s);
stmt.BindText(2, s);
TF_ASSERT_OK(stmt.StepAndReset());
- stmt = db_->Prepare("SELECT a, b FROM T");
+ stmt = db_->PrepareOrDie("SELECT a, b FROM T");
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ(2, stmt.ColumnSize(0));
EXPECT_EQ(2, stmt.ColumnString(0).size());
@@ -92,58 +91,38 @@
TEST_F(SqliteTest, Unicode) {
string s = "要依法治国是赞美那些谁是公义的和惩罚恶人。 - 韩非";
- auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
stmt.BindBlob(1, s);
stmt.BindText(2, s);
TF_ASSERT_OK(stmt.StepAndReset());
- stmt = db_->Prepare("SELECT a, b FROM T");
+ stmt = db_->PrepareOrDie("SELECT a, b FROM T");
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ(s, stmt.ColumnString(0));
EXPECT_EQ(s, stmt.ColumnString(1));
}
TEST_F(SqliteTest, StepAndResetClearsBindings) {
- auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
stmt.BindInt(1, 1);
stmt.BindInt(2, 123);
TF_ASSERT_OK(stmt.StepAndReset());
stmt.BindInt(1, 2);
TF_ASSERT_OK(stmt.StepAndReset());
- stmt = db_->Prepare("SELECT b FROM T ORDER BY a");
+ stmt = db_->PrepareOrDie("SELECT b FROM T ORDER BY a");
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ(123, stmt.ColumnInt(0));
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ(SQLITE_NULL, stmt.ColumnType(0));
}
-TEST_F(SqliteTest, CloseBeforeFinalizeFails) {
- auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
- Status s = db_->Close();
- EXPECT_FALSE(s.ok());
-}
-
-// Rather than bothering to check the status code of creating a
-// statement and every single bind call afterwards, SqliteStatement
-// is designed to carry the first error state forward to Step().
-TEST_F(SqliteTest, ErrorPuntingDoesNotReportLibraryAbuse) {
- auto stmt = db_->Prepare("lol cat");
- EXPECT_FALSE(stmt.status().ok());
- EXPECT_EQ(SQLITE_ERROR, stmt.error());
- stmt.BindInt(1, 1);
- stmt.BindInt(2, 2);
- Status s = stmt.Step(&is_done_);
- EXPECT_EQ(SQLITE_ERROR, stmt.error()); // first error of several
- EXPECT_FALSE(s.ok());
-}
-
TEST_F(SqliteTest, SafeBind) {
string s = "hello";
- auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
stmt.BindBlob(1, s);
stmt.BindText(2, s);
s.at(0) = 'y';
TF_ASSERT_OK(stmt.StepAndReset());
- stmt = db_->Prepare("SELECT a, b FROM T");
+ stmt = db_->PrepareOrDie("SELECT a, b FROM T");
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ("hello", stmt.ColumnString(0));
EXPECT_EQ("hello", stmt.ColumnString(1));
@@ -151,42 +130,42 @@
TEST_F(SqliteTest, UnsafeBind) {
string s = "hello";
- auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
stmt.BindBlobUnsafe(1, s);
stmt.BindTextUnsafe(2, s);
s.at(0) = 'y';
TF_ASSERT_OK(stmt.StepAndReset());
- stmt = db_->Prepare("SELECT a, b FROM T");
+ stmt = db_->PrepareOrDie("SELECT a, b FROM T");
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ("yello", stmt.ColumnString(0));
EXPECT_EQ("yello", stmt.ColumnString(1));
}
TEST_F(SqliteTest, UnsafeColumn) {
- auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
stmt.BindInt(1, 1);
stmt.BindText(2, "hello");
TF_ASSERT_OK(stmt.StepAndReset());
stmt.BindInt(1, 2);
stmt.BindText(2, "there");
TF_ASSERT_OK(stmt.StepAndReset());
- stmt = db_->Prepare("SELECT b FROM T ORDER BY a");
+ stmt = db_->PrepareOrDie("SELECT b FROM T ORDER BY a");
TF_ASSERT_OK(stmt.Step(&is_done_));
- const char* p = stmt.ColumnStringUnsafe(0);
- EXPECT_EQ('h', *p);
+ StringPiece p = stmt.ColumnStringUnsafe(0);
+ EXPECT_EQ('h', *p.data());
TF_ASSERT_OK(stmt.Step(&is_done_));
// This will actually happen, but it's not safe to test this behavior.
- // EXPECT_EQ('t', *p);
+ // EXPECT_EQ('t', *p.data());
}
TEST_F(SqliteTest, NamedParameterBind) {
- auto stmt = db_->Prepare("INSERT INTO T (a) VALUES (:a)");
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a) VALUES (:a)");
stmt.BindText(":a", "lol");
TF_ASSERT_OK(stmt.StepAndReset());
- stmt = db_->Prepare("SELECT COUNT(*) FROM T");
+ stmt = db_->PrepareOrDie("SELECT COUNT(*) FROM T");
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_EQ(1, stmt.ColumnInt(0));
- stmt = db_->Prepare("SELECT a FROM T");
+ stmt = db_->PrepareOrDie("SELECT a FROM T");
TF_ASSERT_OK(stmt.Step(&is_done_));
EXPECT_FALSE(is_done_);
EXPECT_EQ("lol", stmt.ColumnString(0));
@@ -195,57 +174,107 @@
TEST_F(SqliteTest, Statement_DefaultConstructor) {
SqliteStatement stmt;
EXPECT_FALSE(stmt);
- EXPECT_FALSE(stmt.StepAndReset().ok());
- stmt = db_->Prepare("INSERT INTO T (a) VALUES (1)");
+ stmt = db_->PrepareOrDie("INSERT INTO T (a) VALUES (1)");
EXPECT_TRUE(stmt);
EXPECT_TRUE(stmt.StepAndReset().ok());
}
TEST_F(SqliteTest, Statement_MoveConstructor) {
- SqliteStatement stmt{db_->Prepare("INSERT INTO T (a) VALUES (1)")};
+ SqliteStatement stmt{db_->PrepareOrDie("INSERT INTO T (a) VALUES (1)")};
EXPECT_TRUE(stmt.StepAndReset().ok());
}
TEST_F(SqliteTest, Statement_MoveAssignment) {
- SqliteStatement stmt1 = db_->Prepare("INSERT INTO T (a) VALUES (1)");
+ SqliteStatement stmt1 = db_->PrepareOrDie("INSERT INTO T (a) VALUES (1)");
SqliteStatement stmt2;
EXPECT_TRUE(stmt1.StepAndReset().ok());
- EXPECT_FALSE(stmt2.StepAndReset().ok());
+ EXPECT_FALSE(stmt2);
stmt2 = std::move(stmt1);
EXPECT_TRUE(stmt2.StepAndReset().ok());
}
TEST_F(SqliteTest, PrepareFailed) {
- SqliteStatement s = db_->Prepare("SELECT");
- EXPECT_FALSE(s.status().ok());
- EXPECT_NE(string::npos, s.status().error_message().find("SELECT"));
+ SqliteLock lock(*db_);
+ Status s = db_->Prepare("SELECT").status();
+ ASSERT_FALSE(s.ok());
+ EXPECT_NE(string::npos, s.error_message().find("SELECT"));
+ EXPECT_EQ(SQLITE_ERROR, db_->errcode());
}
TEST_F(SqliteTest, BindFailed) {
- SqliteStatement s = db_->Prepare("INSERT INTO T (a) VALUES (123)");
- EXPECT_TRUE(s.status().ok());
- EXPECT_EQ("", s.status().error_message());
- s.BindInt(1, 123);
- EXPECT_FALSE(s.status().ok());
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a) VALUES (123)");
+ stmt.BindInt(1, 123);
+ Status s = stmt.StepOnce();
EXPECT_NE(string::npos,
- s.status().error_message().find("INSERT INTO T (a) VALUES (123)"));
+ s.error_message().find("INSERT INTO T (a) VALUES (123)"))
+ << s.error_message();
}
TEST_F(SqliteTest, SnappyExtension) {
- auto stmt = db_->Prepare("SELECT UNSNAP(SNAP(?))");
+ auto stmt = db_->PrepareOrDie("SELECT UNSNAP(SNAP(?))");
stmt.BindText(1, "hello");
- TF_ASSERT_OK(stmt.Step(&is_done_));
- EXPECT_FALSE(is_done_);
- EXPECT_EQ("hello", stmt.ColumnString(0));
+ EXPECT_EQ("hello", stmt.StepOnceOrDie().ColumnString(0));
}
TEST_F(SqliteTest, SnappyBinaryCompatibility) {
- auto stmt = db_->Prepare(
- "SELECT UNSNAP(X'03207C746F6461792069732074686520656E64206F66207468652"
- "072657075626C6963')");
- TF_ASSERT_OK(stmt.Step(&is_done_));
- EXPECT_FALSE(is_done_);
- EXPECT_EQ("today is the end of the republic", stmt.ColumnString(0));
+ EXPECT_EQ(
+ "today is the end of the republic",
+ db_->PrepareOrDie("SELECT UNSNAP(X'03207C746F6461792069732074686520656E64"
+ "206F66207468652072657075626C6963')")
+ .StepOnceOrDie()
+ .ColumnString(0));
+}
+
+TEST(SqliteOpenTest, CloseConnectionBeforeStatement_KeepsConnectionOpen) {
+ auto s = Sqlite::OpenOrDie(":memory:")->PrepareOrDie("SELECT ? + ?");
+ s.BindInt(1, 7);
+ s.BindInt(2, 3);
+ EXPECT_EQ(10, s.StepOnceOrDie().ColumnInt(0));
+}
+
+TEST_F(SqliteTest, TransactionRollback) {
+ {
+ SqliteTransaction txn(*db_);
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
+ stmt.BindDouble(1, 6.28318530);
+ stmt.BindDouble(2, 1.61803399);
+ TF_ASSERT_OK(stmt.StepAndReset());
+ }
+ EXPECT_EQ(
+ 0,
+ db_->PrepareOrDie("SELECT COUNT(*) FROM T").StepOnceOrDie().ColumnInt(0));
+}
+
+TEST_F(SqliteTest, TransactionCommit) {
+ {
+ SqliteTransaction txn(*db_);
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
+ stmt.BindDouble(1, 6.28318530);
+ stmt.BindDouble(2, 1.61803399);
+ TF_ASSERT_OK(stmt.StepAndReset());
+ TF_ASSERT_OK(txn.Commit());
+ }
+ EXPECT_EQ(
+ 1,
+ db_->PrepareOrDie("SELECT COUNT(*) FROM T").StepOnceOrDie().ColumnInt(0));
+}
+
+TEST_F(SqliteTest, TransactionCommitMultipleTimes) {
+ {
+ SqliteTransaction txn(*db_);
+ auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
+ stmt.BindDouble(1, 6.28318530);
+ stmt.BindDouble(2, 1.61803399);
+ TF_ASSERT_OK(stmt.StepAndReset());
+ TF_ASSERT_OK(txn.Commit());
+ stmt.BindDouble(1, 6.28318530);
+ stmt.BindDouble(2, 1.61803399);
+ TF_ASSERT_OK(stmt.StepAndReset());
+ TF_ASSERT_OK(txn.Commit());
+ }
+ EXPECT_EQ(
+ 2,
+ db_->PrepareOrDie("SELECT COUNT(*) FROM T").StepOnceOrDie().ColumnInt(0));
}
} // namespace
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 9329749..ddc4910 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -1413,6 +1413,10 @@
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
c->set_output(0, c->Vector(2));
c->set_output(1, c->Scalar());
+ if (c->input_handle_shapes_and_types(0)) {
+ c->set_output_handle_shapes_and_types(
+ 0, *c->input_handle_shapes_and_types(0));
+ }
return Status::OK();
})
.Doc(R"doc(
diff --git a/tensorflow/core/platform/cloud/expiring_lru_cache.h b/tensorflow/core/platform/cloud/expiring_lru_cache.h
index 3fc23a4..c738497 100644
--- a/tensorflow/core/platform/cloud/expiring_lru_cache.h
+++ b/tensorflow/core/platform/cloud/expiring_lru_cache.h
@@ -88,6 +88,13 @@
return s;
}
+ /// Clear the cache.
+ void Clear() {
+ mutex_lock lock(mu_);
+ cache_.clear();
+ lru_list_.clear();
+ }
+
/// Accessors for cache parameters.
uint64 max_age() const { return max_age_; }
size_t max_entries() const { return max_entries_; }
diff --git a/tensorflow/core/platform/cloud/expiring_lru_cache_test.cc b/tensorflow/core/platform/cloud/expiring_lru_cache_test.cc
index 8f8d574..3bc6db3 100644
--- a/tensorflow/core/platform/cloud/expiring_lru_cache_test.cc
+++ b/tensorflow/core/platform/cloud/expiring_lru_cache_test.cc
@@ -152,5 +152,27 @@
EXPECT_EQ(num_compute_calls, 6);
}
+TEST(ExpiringLRUCacheTest, Clear) {
+ ExpiringLRUCache<int> cache(1, 4);
+ cache.Insert("a", 1);
+ cache.Insert("b", 2);
+ cache.Insert("c", 3);
+ cache.Insert("d", 4);
+ int value = 0;
+ EXPECT_TRUE(cache.Lookup("a", &value));
+ EXPECT_EQ(value, 1);
+ EXPECT_TRUE(cache.Lookup("b", &value));
+ EXPECT_EQ(value, 2);
+ EXPECT_TRUE(cache.Lookup("c", &value));
+ EXPECT_EQ(value, 3);
+ EXPECT_TRUE(cache.Lookup("d", &value));
+ EXPECT_EQ(value, 4);
+ cache.Clear();
+ EXPECT_FALSE(cache.Lookup("a", &value));
+ EXPECT_FALSE(cache.Lookup("b", &value));
+ EXPECT_FALSE(cache.Lookup("c", &value));
+ EXPECT_FALSE(cache.Lookup("d", &value));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/file_block_cache.cc b/tensorflow/core/platform/cloud/file_block_cache.cc
index 6831600..0375af5 100644
--- a/tensorflow/core/platform/cloud/file_block_cache.cc
+++ b/tensorflow/core/platform/cloud/file_block_cache.cc
@@ -237,6 +237,14 @@
}
}
+// Empties the block map and both the LRU and LRA lists, and resets
+// cache_size_ to zero, all while holding mu_.
+void FileBlockCache::Flush() {
+  mutex_lock guard(mu_);
+  cache_size_ = 0;
+  block_map_.clear();
+  lra_list_.clear();
+  lru_list_.clear();
+}
+
void FileBlockCache::RemoveFile(const string& filename) {
mutex_lock lock(mu_);
RemoveFile_Locked(filename);
diff --git a/tensorflow/core/platform/cloud/file_block_cache.h b/tensorflow/core/platform/cloud/file_block_cache.h
index 74e792a..5c180e2 100644
--- a/tensorflow/core/platform/cloud/file_block_cache.h
+++ b/tensorflow/core/platform/cloud/file_block_cache.h
@@ -90,6 +90,9 @@
/// Remove all cached blocks for `filename`.
void RemoveFile(const string& filename) LOCKS_EXCLUDED(mu_);
+ /// Remove all cached data.
+ void Flush() LOCKS_EXCLUDED(mu_);
+
/// Accessors for cache parameters.
size_t block_size() const { return block_size_; }
size_t max_bytes() const { return max_bytes_; }
diff --git a/tensorflow/core/platform/cloud/file_block_cache_test.cc b/tensorflow/core/platform/cloud/file_block_cache_test.cc
index ae87e0d..596fdbf 100644
--- a/tensorflow/core/platform/cloud/file_block_cache_test.cc
+++ b/tensorflow/core/platform/cloud/file_block_cache_test.cc
@@ -495,5 +495,25 @@
EXPECT_EQ(1, num_requests);
}
+
+TEST(FileBlockCacheTest, Flush) {
+  // Counts how often the cache has to call the backing fetcher.
+  int fetch_count = 0;
+  auto counting_fetcher = [&fetch_count](const string& filename,
+                                         size_t offset, size_t n,
+                                         char* buffer,
+                                         size_t* bytes_transferred) {
+    ++fetch_count;
+    memset(buffer, 'x', n);
+    *bytes_transferred = n;
+    return Status::OK();
+  };
+  FileBlockCache cache(16, 32, 0, counting_fetcher);
+  std::vector<char> out;
+  // Two identical reads: only the first one may reach the fetcher.
+  for (int attempt = 0; attempt < 2; ++attempt) {
+    TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out));
+  }
+  EXPECT_EQ(fetch_count, 1);
+  // Once flushed, the very same block has to be fetched again.
+  cache.Flush();
+  TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out));
+  EXPECT_EQ(fetch_count, 2);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index dffd1c4..970a6b1 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -1377,6 +1377,15 @@
return Status::OK();
}
+// Empties the file block cache, the stat cache and the matching-paths cache.
+// Handy for giving memory back after the filesystem work is finished (for
+// example once a model has been loaded), or for returning the filesystem to a
+// consistent state.
+void GcsFileSystem::FlushCaches() {
+  matching_paths_cache_->Clear();
+  stat_cache_->Clear();
+  file_block_cache_->Flush();
+}
+
// Creates an HttpRequest and sets several parameters that are common to all
// requests. All code (in GcsFileSystem) that creates an HttpRequest should
// go through this method, rather than directly using http_request_factory_.
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index 731f97a..adde161 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -84,6 +84,8 @@
Status DeleteRecursively(const string& dirname, int64* undeleted_files,
int64* undeleted_dirs) override;
+ void FlushCaches() override;
+
/// These accessors are mainly for testing purposes, to verify that the
/// environment variables that control these parameters are handled correctly.
size_t block_size() const { return file_block_cache_->block_size(); }
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index 32bd946..772aec5 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -195,6 +195,49 @@
EXPECT_EQ("0123", result);
}
+TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) {
+ // Our underlying file in this test is a 15 byte file with contents
+ // "0123456789abcde". Two identical fake requests for bytes 0-8 are queued:
+ // one for the initial read and one for the read after FlushCaches().
+ std::vector<HttpRequest*> requests(
+ {new FakeHttpRequest(
+ "Uri: https://storage.googleapis.com/bucket/random_access.txt\n"
+ "Auth Token: fake_token\n"
+ "Range: 0-8\n"
+ "Timeouts: 5 1 20\n",
+ "012345678"),
+ new FakeHttpRequest(
+ "Uri: https://storage.googleapis.com/bucket/random_access.txt\n"
+ "Auth Token: fake_token\n"
+ "Range: 0-8\n"
+ "Timeouts: 5 1 20\n",
+ "012345678")});
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 9 /* block size */, 18 /* max bytes */,
+ 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */,
+ 0 /* initial retry delay */, kTestTimeoutConfig);
+
+ char scratch[100];
+ StringPiece result;
+ std::unique_ptr<RandomAccessFile> file;
+ TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
+ // Read the first chunk. The cache will be populated with the first block of
+ // 9 bytes.
+ scratch[5] = 'x';  // Sentinel that is checked after the read below.
+ TF_EXPECT_OK(file->Read(0, 4, &result, scratch));
+ EXPECT_EQ("0123", result);
+ EXPECT_EQ(scratch[5], 'x'); // Make sure we only copied 4 bytes.
+ // Flush caches and read the second chunk. This will be a cache miss, and
+ // the same block will be fetched again.
+ fs.FlushCaches();
+ TF_EXPECT_OK(file->Read(4, 4, &result, scratch));
+ EXPECT_EQ("4567", result);
+}
+
TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) {
// Our underlying file in this test is a 16 byte file with contents
// "0123456789abcdef".
@@ -1270,6 +1313,50 @@
}
}
+TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) {
+ std::vector<HttpRequest*> requests(
+ {new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
+ "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ "{\"items\": [ "
+ " { \"name\": \"path/subpath/file2.txt\" }]}"),
+ new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
+ "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ "{\"items\": [ "
+ " { \"name\": \"path/subpath/file2.txt\" }]}")});
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 3600 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */,
+ 0 /* initial retry delay*/, kTestTimeoutConfig);
+
+ // This loop should trigger the first HTTP request to GCS; only two fake
+ // requests are queued in total, one for each loop in this test.
+ for (int i = 0; i < 10; i++) {
+ std::vector<string> result;
+ TF_EXPECT_OK(
+ fs.GetMatchingPaths("gs://bucket/path/subpath/file2.txt", &result));
+ EXPECT_EQ(std::vector<string>({"gs://bucket/path/subpath/file2.txt"}),
+ result);
+ }
+ // After flushing caches, there should be another (identical) request to GCS.
+ fs.FlushCaches();
+ for (int i = 0; i < 10; i++) {
+ std::vector<string> result;
+ TF_EXPECT_OK(
+ fs.GetMatchingPaths("gs://bucket/path/subpath/file2.txt", &result));
+ EXPECT_EQ(std::vector<string>({"gs://bucket/path/subpath/file2.txt"}),
+ result);
+ }
+}
+
TEST(GcsFileSystemTest, DeleteFile) {
std::vector<HttpRequest*> requests(
{new FakeHttpRequest(
@@ -1895,6 +1982,50 @@
}
}
+TEST(GcsFileSystemTest, Stat_Cache_Flush) {
+ std::vector<HttpRequest*> requests(
+ {new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
+ "file.txt?fields=size%2Cupdated\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ strings::StrCat("{\"size\": \"1010\","
+ "\"updated\": \"2016-04-29T23:15:24.896Z\"}")),
+ new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
+ "file.txt?fields=size%2Cupdated\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ strings::StrCat("{\"size\": \"1010\","
+ "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */,
+ 0 /* initial retry delay*/, kTestTimeoutConfig);
+ // There should be a single HTTP request to GCS for fs.Stat in this loop.
+ for (int i = 0; i < 10; i++) {
+ FileStatistics stat;
+ TF_EXPECT_OK(fs.Stat("gs://bucket/file.txt", &stat));
+ EXPECT_EQ(1010, stat.length);
+ EXPECT_NEAR(1461971724896, stat.mtime_nsec / 1000 / 1000, 1);  // ms, +/-1.
+ EXPECT_FALSE(stat.is_directory);
+ }
+ // After flushing caches, there should be a second request to GCS for fs.Stat.
+ fs.FlushCaches();
+ for (int i = 0; i < 10; i++) {
+ FileStatistics stat;
+ TF_EXPECT_OK(fs.Stat("gs://bucket/file.txt", &stat));
+ EXPECT_EQ(1010, stat.length);
+ EXPECT_NEAR(1461971724896, stat.mtime_nsec / 1000 / 1000, 1);  // ms, +/-1.
+ EXPECT_FALSE(stat.is_directory);
+ }
+}
+
TEST(GcsFileSystemTest, IsDirectory_NotFound) {
std::vector<HttpRequest*> requests(
{new FakeHttpRequest(
diff --git a/tensorflow/core/platform/default/stacktrace.h b/tensorflow/core/platform/default/stacktrace.h
index 5f30732..436716d 100644
--- a/tensorflow/core/platform/default/stacktrace.h
+++ b/tensorflow/core/platform/default/stacktrace.h
@@ -17,12 +17,61 @@
#define TENSORFLOW_CORE_PLATFORM_DEFAULT_STACKTRACE_H_
#include "tensorflow/core/platform/platform.h"
+#if !defined(IS_MOBILE_PLATFORM) && defined(PLATFORM_POSIX) && \
+ (defined(__clang__) || defined(__GNUC__))
+#define TF_GENERATE_BACKTRACE
+#endif
+
+#if defined(TF_GENERATE_BACKTRACE)
+#include <dlfcn.h>
+#include <execinfo.h>
+#include <stdio.h>
+#include <string.h>
+#include <unistd.h>
+#endif // defined(TF_GENERATE_BACKTRACE)
+
+#include <sstream>
+#include <string>
+#include "tensorflow/core/platform/abi.h"
namespace tensorflow {
-inline string CurrentStackTrace() { return "No stack trace available"; }
+// Returns a human-readable stack trace of the calling thread: a begin marker,
+// one tab-indented line per captured frame (at most 128 frames), and an end
+// marker. Frames are demangled when possible. On builds where
+// TF_GENERATE_BACKTRACE is not defined, returns "No stack trace available".
+inline std::string CurrentStackTrace() {
+#if defined(TF_GENERATE_BACKTRACE)
+  std::stringstream ss("");
+  ss << "*** Begin stack trace ***" << std::endl;
-inline void DebugWriteToString(const char* data, void* arg) {}
+  // Get the mangled stack trace.
+  int buffer_size = 128;
+  void* trace[128];
+  buffer_size = backtrace(trace, buffer_size);
+
+  for (int i = 0; i < buffer_size; ++i) {
+    // Frames without a resolvable symbol are printed as an empty name.
+    const char* symbol = "";
+    Dl_info info;
+    if (dladdr(trace[i], &info)) {
+      if (info.dli_sname != nullptr) {
+        symbol = info.dli_sname;
+      }
+    }
+
+    // Fall back to the raw symbol when demangling yields an empty string.
+    std::string demangled = tensorflow::port::MaybeAbiDemangle(symbol);
+    if (demangled.length()) {
+      ss << "\t" << demangled << std::endl;
+    } else {
+      ss << "\t" << symbol << std::endl;
+    }
+  }
+
+  ss << "*** End stack trace ***" << std::endl;
+  return ss.str();
+#else
+  // Without backtrace support the function must still return a value; flowing
+  // off the end of a value-returning function is undefined behavior.
+  return "No stack trace available";
+#endif  // defined(TF_GENERATE_BACKTRACE)
+}
+
+// Appends the C string `data` to the std::string that `arg` points to.
+inline void DebugWriteToString(const char* data, void* arg) {
+  std::string* sink = reinterpret_cast<std::string*>(arg);
+  sink->append(data);
+}
// A dummy class that does nothing. Someday, add real support.
class SavedStackTrace {
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
index 816ffa4..d617ff1 100644
--- a/tensorflow/core/platform/env.cc
+++ b/tensorflow/core/platform/env.cc
@@ -108,6 +108,18 @@
return file_system_registry_->Register(scheme, std::move(factory));
}
+// Calls FlushCaches() on the filesystem registered for every scheme. Lookup
+// failures are propagated through TF_RETURN_IF_ERROR.
+Status Env::FlushFileSystemCaches() {
+  std::vector<string> registered_schemes;
+  TF_RETURN_IF_ERROR(GetRegisteredFileSystemSchemes(&registered_schemes));
+  for (const string& scheme : registered_schemes) {
+    // A URI with empty host and path is used to look up the filesystem.
+    const string probe_uri = io::CreateURI(scheme, "", "");
+    FileSystem* file_system = nullptr;
+    TF_RETURN_IF_ERROR(GetFileSystemForFile(probe_uri, &file_system));
+    file_system->FlushCaches();
+  }
+  return Status::OK();
+}
+
Status Env::NewRandomAccessFile(const string& fname,
std::unique_ptr<RandomAccessFile>* result) {
FileSystem* fs;
diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h
index a0adf70..557bfa8 100644
--- a/tensorflow/core/platform/env.h
+++ b/tensorflow/core/platform/env.h
@@ -68,10 +68,13 @@
/// \brief Returns the file system schemes registered for this Env.
virtual Status GetRegisteredFileSystemSchemes(std::vector<string>* schemes);
- // \brief Register a file system for a scheme.
+ /// \brief Register a file system for a scheme.
virtual Status RegisterFileSystem(const string& scheme,
FileSystemRegistry::Factory factory);
+ /// \brief Flush filesystem caches for all registered filesystems.
+ Status FlushFileSystemCaches();
+
/// \brief Creates a brand new random access read-only file with the
/// specified name.
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 233c370..47ddf0c 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -281,6 +281,15 @@
StringPiece scheme, host, path;
io::ParseURI(dir, &scheme, &host, &path);
if (path.empty()) return errors::NotFound(dir, " not found");
+ // The special "flushed" file exists only if the filesystem's caches have
+ // been flushed.
+ if (path == "/flushed") {
+ if (flushed_) {
+ return Status::OK();
+ } else {
+ return errors::NotFound("FlushCaches() not called yet");
+ }
+ }
return Env::Default()->FileExists(io::JoinPath(BaseDir(), path));
}
@@ -295,10 +304,23 @@
}
return Env::Default()->CreateDir(io::JoinPath(BaseDir(), path));
}
+
+ void FlushCaches() override { flushed_ = true; }  // Makes "/flushed" exist.
+
+ private:
+ bool flushed_ = false;
};
REGISTER_FILE_SYSTEM("tmpdirfs", TmpDirFileSystem);
+TEST_F(DefaultEnvTest, FlushFileSystemCaches) {
+  // The tmpdirfs test filesystem reports "/flushed" as existing only once its
+  // FlushCaches() has been called.
+  const string marker_path = "tmpdirfs://testhost/flushed";
+  Env* const default_env = Env::Default();
+  EXPECT_EQ(error::Code::NOT_FOUND,
+            default_env->FileExists(marker_path).code());
+  TF_EXPECT_OK(default_env->FlushFileSystemCaches());
+  TF_EXPECT_OK(default_env->FileExists(marker_path));
+}
+
TEST_F(DefaultEnvTest, RecursivelyCreateDirWithUri) {
Env* env = Env::Default();
const string create_path = "tmpdirfs://testhost/a/b/c/d";
diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc
index 938f5af..1475589 100644
--- a/tensorflow/core/platform/file_system.cc
+++ b/tensorflow/core/platform/file_system.cc
@@ -73,6 +73,8 @@
return Status(tensorflow::error::FAILED_PRECONDITION, "Not a directory");
}
+void FileSystem::FlushCaches() {}
+
RandomAccessFile::~RandomAccessFile() {}
WritableFile::~WritableFile() {}
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 903df96..d32efce 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -206,6 +206,9 @@
/// * UNIMPLEMENTED - The file factory doesn't support directories.
virtual Status IsDirectory(const string& fname);
+ /// \brief Flushes any cached filesystem objects from memory.
+ virtual void FlushCaches();
+
FileSystem() {}
virtual ~FileSystem();
diff --git a/tensorflow/core/platform/stacktrace_handler.cc b/tensorflow/core/platform/stacktrace_handler.cc
new file mode 100644
index 0000000..ff31c97
--- /dev/null
+++ b/tensorflow/core/platform/stacktrace_handler.cc
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/platform.h"
+
+#if !defined(PLATFORM_GOOGLE) && !defined(IS_MOBILE_PLATFORM) && \
+ defined(PLATFORM_POSIX) && (defined(__clang__) || defined(__GNUC__))
+#define TF_GENERATE_STACKTRACE
+#endif
+
+#if defined(TF_GENERATE_STACKTRACE)
+#include <errno.h>
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/time.h>
+#include <unistd.h>
+#include <string>
+
+#include "tensorflow/core/platform/abi.h"
+#include "tensorflow/core/platform/stacktrace.h"
+
+#endif // defined(TF_GENERATE_STACKTRACE)
+
+namespace tensorflow {
+namespace testing {
+
+#if defined(TF_GENERATE_STACKTRACE)
+// This function will print stacktrace to STDERR, using only write(),
+// backtrace() and backtrace_symbols_fd(). It is meant to avoid malloc so the
+// stack can be dumped even when the heap is corrupted (NOTE(review): confirm
+// backtrace() never allocates on this platform). Symbols are printed mangled.
+inline void SafePrintStackTrace() {
+ static const char begin_msg[] = "*** BEGIN MANGLED STACK TRACE ***\n";
+ (void)write(STDERR_FILENO, begin_msg, strlen(begin_msg));
+
+ int buffer_size = 128;
+ void *trace[128];
+ // Capture up to 128 frames; backtrace() returns how many it stored.
+ buffer_size = backtrace(trace, buffer_size);
+
+ // Print a mangled stacktrace to STDERR as safely as possible.
+ backtrace_symbols_fd(trace, buffer_size, STDERR_FILENO);
+
+ static const char end_msg[] = "*** END MANGLED STACK TRACE ***\n\n";
+ (void)write(STDERR_FILENO, end_msg, strlen(end_msg));
+}
+
+static void StacktraceHandler(int sig, siginfo_t *si, void *v) {  // si, v unused
+ // Make sure our handler does not deadlock. And this should be the last thing
+ // our program does. Therefore, set a timer to kill the program in 60
+ // seconds.
+ struct itimerval timer;
+ timer.it_value.tv_sec = 60;
+ timer.it_value.tv_usec = 0;
+ timer.it_interval.tv_sec = 0;  // Zero interval: the timer is not re-armed.
+ timer.it_interval.tv_usec = 0;
+ setitimer(ITIMER_REAL, &timer, 0);
+
+ struct sigaction sa_timeout;
+ memset(&sa_timeout, 0, sizeof(sa_timeout));
+ sa_timeout.sa_handler = SIG_DFL;  // Default SIGALRM action for the timer.
+ sigaction(SIGALRM, &sa_timeout, 0);
+
+ char buf[128];
+
+ snprintf(buf, sizeof(buf), "*** Received signal %d ***\n", sig);
+ (void)write(STDERR_FILENO, buf, strlen(buf));
+
+ // Print "a" stack trace, as safely as possible.
+ SafePrintStackTrace();
+
+ // Up until this line, we made sure not to allocate memory, to be able to dump
+ // a stack trace even in the event of heap corruption. After this line, we
+ // will try to print more human readable things to the terminal.
+ // But these have a higher probability to fail.
+ std::string stacktrace = CurrentStackTrace();
+ (void)write(STDERR_FILENO, stacktrace.c_str(), stacktrace.length());
+
+ // Abort the program. SIGABRT is first reset to its default action so that
+ // abort() is not caught by a handler again.
+ struct sigaction sa;
+ sigemptyset(&sa.sa_mask);
+ sa.sa_flags = 0;
+ sa.sa_handler = SIG_DFL;
+ sigaction(SIGABRT, &sa, NULL);
+ abort();
+}
+
+// Registers StacktraceHandler (with SA_SIGINFO | SA_RESETHAND) for SIGSEGV,
+// SIGABRT, SIGBUS, SIGILL and SIGFPE. A warning is written to stderr when a
+// handler cannot be installed, or when it replaces a non-default handler.
+void InstallStacktraceHandler() {
+  const int handled_signals[] = {SIGSEGV, SIGABRT, SIGBUS, SIGILL, SIGFPE};
+
+  for (const int sig : handled_signals) {
+    struct sigaction new_action;
+    struct sigaction old_action;
+
+    sigemptyset(&new_action.sa_mask);
+    new_action.sa_flags = SA_SIGINFO | SA_RESETHAND;
+    new_action.sa_sigaction = &StacktraceHandler;
+
+    char warning[128];
+    if (sigaction(sig, &new_action, &old_action) != 0) {
+      snprintf(warning, sizeof(warning),
+               "Warning, can't install backtrace signal handler for signal %d, "
+               "errno:%d \n",
+               sig, errno);
+    } else if (old_action.sa_handler != SIG_DFL) {
+      snprintf(warning, sizeof(warning),
+               "Warning, backtrace signal handler for signal %d overwrote "
+               "previous handler.\n",
+               sig);
+    } else {
+      // Installed cleanly over the default disposition: nothing to report.
+      continue;
+    }
+    (void)write(STDERR_FILENO, warning, strlen(warning));
+  }
+}
+
+#else
+void InstallStacktraceHandler() {}
+#endif // defined(TF_GENERATE_STACKTRACE)
+
+} // namespace testing
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/stacktrace_handler.h b/tensorflow/core/platform/stacktrace_handler.h
new file mode 100644
index 0000000..d36c82c
--- /dev/null
+++ b/tensorflow/core/platform/stacktrace_handler.h
@@ -0,0 +1,28 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_
+#define TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_
+
+namespace tensorflow {
+namespace testing {
+
+// Installs signal handlers to print out stack trace.
+// The handlers cover SIGSEGV, SIGABRT, SIGBUS, SIGILL and SIGFPE; each prints
+// the trace to stderr and then aborts the process. This is a no-op on builds
+// where stacktrace_handler.cc does not define TF_GENERATE_STACKTRACE.
+void InstallStacktraceHandler();
+
+}  // namespace testing
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_
diff --git a/tensorflow/core/platform/stacktrace_handler_test.cc b/tensorflow/core/platform/stacktrace_handler_test.cc
new file mode 100644
index 0000000..958c7de
--- /dev/null
+++ b/tensorflow/core/platform/stacktrace_handler_test.cc
@@ -0,0 +1,82 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Testing proper operation of the stacktrace handler.
+
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <string>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+#define READ_BUFFER_SIZE 1024
+
+// Forks a child whose stdout/stderr feed a pipe, sends it SIGABRT, and checks
+// that the output captured by the parent contains a demangled gtest frame.
+TEST(StacktraceHandlerTest, GeneratesStacktrace) {
+  // Create a pipe to write/read the child stdout. Nothing below is meaningful
+  // without it, so stop right away if it cannot be created.
+  int test_pipe[2];
+  ASSERT_EQ(pipe(test_pipe), 0);
+
+  // Fork the process. fork() returns -1 on failure, and kill(-1, ...) would
+  // signal every process this user may signal, so bail out before that.
+  int test_pid = fork();
+  ASSERT_NE(test_pid, -1);
+
+  if (test_pid == 0) {
+    // Child process.
+    // Close the read end of the pipe, redirect stdout and sleep.
+    close(test_pipe[0]);
+    dup2(test_pipe[1], STDOUT_FILENO);
+    dup2(test_pipe[1], STDERR_FILENO);
+    sleep(10);
+    // Only reached if the parent's signal never arrived. Leave immediately
+    // instead of falling through and continuing the test run in the child.
+    _exit(1);
+  } else {
+    // Parent process.
+    // Close the write end of the pipe, wait a little and send SIGABRT to the
+    // child process. Then watch the pipe.
+    close(test_pipe[1]);
+    sleep(1);
+
+    // Send the signal.
+    kill(test_pid, SIGABRT);
+
+    // Read from the pipe until read() returns 0 (end of file) or an error.
+    char buffer[READ_BUFFER_SIZE];
+    std::string child_output = "";
+    while (true) {
+      int read_length = read(test_pipe[0], buffer, READ_BUFFER_SIZE);
+      if (read_length > 0) {
+        child_output += std::string(buffer, read_length);
+      } else {
+        break;
+      }
+    }
+    close(test_pipe[0]);
+
+    // Just make sure we can detect one of the calls in testing stack.
+    string test_stack_frame = "testing::internal::UnitTestImpl::RunAllTests()";
+
+    // Print the stack trace detected for information.
+    LOG(INFO) << "Output from the child process:";
+    LOG(INFO) << child_output;
+
+    EXPECT_NE(child_output.find(test_stack_frame), std::string::npos);
+  }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/test_main.cc b/tensorflow/core/platform/test_main.cc
index 96c88af..677114f 100644
--- a/tensorflow/core/platform/test_main.cc
+++ b/tensorflow/core/platform/test_main.cc
@@ -27,12 +27,14 @@
#include <iostream>
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/stacktrace_handler.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
GTEST_API_ int main(int argc, char** argv) {
std::cout << "Running main() from test_main.cc\n";
+ tensorflow::testing::InstallStacktraceHandler();
testing::InitGoogleTest(&argc, argv);
for (int i = 1; i < argc; i++) {
if (tensorflow::StringPiece(argv[i]).starts_with("--benchmarks=")) {
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index 6d6297d..47b1251 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -113,6 +113,29 @@
[Stack Overflow](http://stackoverflow.com/questions/tagged/tensorflow)
for possible solutions. You can skip reading the rest of this document.
+### GPU support
+
+If your Linux system has an NVIDIA® GPU and your TensorFlow Java program
+requires GPU acceleration, then add the following to the project's `pom.xml`
+instead:
+
+```xml
+<dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>libtensorflow</artifactId>
+ <version>1.4.0</version>
+</dependency>
+<dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>libtensorflow_jni_gpu</artifactId>
+ <version>1.4.0</version>
+</dependency>
+```
+
+GPU acceleration is available via Maven only for Linux and only if your system
+meets the
+@{$install_linux#determine_which_tensorflow_to_install$requirements for GPU}.
+
## Using TensorFlow with JDK
This section describes how to use TensorFlow using the `java` and `javac`
diff --git a/tensorflow/docs_src/performance/datasets_performance.md b/tensorflow/docs_src/performance/datasets_performance.md
index dd55849..4f95e17 100644
--- a/tensorflow/docs_src/performance/datasets_performance.md
+++ b/tensorflow/docs_src/performance/datasets_performance.md
@@ -224,7 +224,7 @@
to:
```
-dataset = files.apply(tf.data.contrib.parallel_interleave(
+dataset = files.apply(tf.contrib.data.parallel_interleave(
tf.data.TFRecordDataset, cycle_length=FLAGS.num_parallel_readers))
```
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index dceaf71..97467c5 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3719,6 +3719,7 @@
srcs = ["training/session_manager_test.py"],
additional_deps = [
":array_ops",
+ ":control_flow_ops",
":client",
":client_testlib",
":errors",
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index 00c38ea..909150e 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -154,8 +154,7 @@
sess: A tensorflow Session object.
"""
- _check_type(sess, (session.SessionInterface,
- monitored_session.MonitoredSession))
+ _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
self.session = sess
@@ -359,8 +358,7 @@
NotImplementedError: If a non-DirectSession sess object is received.
"""
- _check_type(sess, (session.SessionInterface,
- monitored_session.MonitoredSession))
+ _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
# The session being wrapped.
self._sess = sess
diff --git a/tensorflow/python/debug/wrappers/framework_test.py b/tensorflow/python/debug/wrappers/framework_test.py
index 5240e0d..73e08ce 100644
--- a/tensorflow/python/debug/wrappers/framework_test.py
+++ b/tensorflow/python/debug/wrappers/framework_test.py
@@ -271,9 +271,9 @@
def testSessionInitInvalidSessionType(self):
"""Attempt to wrap a non-Session-type object should cause an exception."""
- sess = "not a session"
+ wrapper = TestDebugWrapperSessionBadAction(self._sess)
with self.assertRaisesRegexp(TypeError, "Expected type .*; got type .*"):
- TestDebugWrapperSessionBadAction(sess)
+ TestDebugWrapperSessionBadAction(wrapper)
def testSessionInitBadActionValue(self):
with self.assertRaisesRegexp(
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 8aec242..3173afc 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -133,6 +133,9 @@
"""Set a global eager mode seed for random ops."""
self._seed = seed
self._rng = random.Random(self._seed)
+ # Also clear the kernel cache, to reset any existing seeds
+ if self._context_handle is not None:
+ pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)
def _internal_operation_seed(self):
"""Returns a fake operation seed.
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 487f1b0..d9c1c3c 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -380,6 +380,93 @@
self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+ def testReduceSumAlongHWC(self):
+ if test.is_gpu_available(cuda_only=True):  # Requires a CUDA GPU.
+ random_seed.set_random_seed(0)
+ x = random_ops.truncated_normal([1, 784], seed=0)
+ conv = _two_layer_model(x)
+ reduce_sum = math_ops.reduce_sum(conv, axis=[1, 2, 3])  # H, W, C per the test name.
+ output = array_ops.identity(reduce_sum)
+
+ with session.Session() as sess:  # Reference run with the default config.
+ output_val_ref = sess.run(output)
+
+ with session.Session(config=_get_config()) as sess:  # Run with _get_config().
+ metadata = config_pb2.RunMetadata()
+ output_val = sess.run(output, run_metadata=metadata)
+
+ nodes = []
+ num_transposes = 0
+ for node in metadata.cost_graph.node:
+ if node.name.startswith('LayoutOptimizerTranspose'):
+ num_transposes += 1
+ nodes.append(node.name)
+
+ # Three transposes were initially added in the Expand phase of
+ # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+ expected_num_transposes = 1
+ self.assertEqual(expected_num_transposes, num_transposes)
+ self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+ self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
+ def testReduceSumAlongNHW(self):
+ if test.is_gpu_available(cuda_only=True):  # Requires a CUDA GPU.
+ random_seed.set_random_seed(0)
+ x = random_ops.truncated_normal([1, 784], seed=0)
+ conv = _two_layer_model(x)
+ reduce_sum = math_ops.reduce_sum(conv, axis=[0, 1, 2])  # N, H, W per the test name.
+ output = array_ops.identity(reduce_sum)
+
+ with session.Session() as sess:  # Reference run with the default config.
+ output_val_ref = sess.run(output)
+
+ with session.Session(config=_get_config()) as sess:  # Run with _get_config().
+ metadata = config_pb2.RunMetadata()
+ output_val = sess.run(output, run_metadata=metadata)
+
+ nodes = []
+ num_transposes = 0
+ for node in metadata.cost_graph.node:
+ if node.name.startswith('LayoutOptimizerTranspose'):
+ num_transposes += 1
+ nodes.append(node.name)
+
+ # Three transposes were initially added in the Expand phase of
+ # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+ expected_num_transposes = 1
+ self.assertEqual(expected_num_transposes, num_transposes)
+ self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+ self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
+ def testReduceSumAlongC(self):
+ if test.is_gpu_available(cuda_only=True):  # Requires a CUDA GPU.
+ random_seed.set_random_seed(0)
+ x = random_ops.truncated_normal([1, 784], seed=0)
+ conv = _two_layer_model(x)
+ reduce_sum = math_ops.reduce_sum(conv, axis=[3])  # C per the test name.
+ output = array_ops.identity(reduce_sum)
+
+ with session.Session() as sess:  # Reference run with the default config.
+ output_val_ref = sess.run(output)
+
+ with session.Session(config=_get_config()) as sess:  # Run with _get_config().
+ metadata = config_pb2.RunMetadata()
+ output_val = sess.run(output, run_metadata=metadata)
+
+ nodes = []
+ num_transposes = 0
+ for node in metadata.cost_graph.node:
+ if node.name.startswith('LayoutOptimizerTranspose'):
+ num_transposes += 1
+ nodes.append(node.name)
+
+ # Three transposes were initially added in the Expand phase of
+ # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+ expected_num_transposes = 1
+ self.assertEqual(expected_num_transposes, num_transposes)
+ self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+ self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
def testConcatWithControlDependency(self):
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(0)
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py
index 936da39..7650bfb 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py
@@ -399,7 +399,7 @@
model.add(keras.layers.Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
- np.random.seed(1337)
+ np.random.seed(43)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=train_samples,
test_samples=test_samples,
diff --git a/tensorflow/python/kernel_tests/random/random_ops_test.py b/tensorflow/python/kernel_tests/random/random_ops_test.py
index 56aaa53..5a2903a 100644
--- a/tensorflow/python/kernel_tests/random/random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random/random_ops_test.py
@@ -21,6 +21,7 @@
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -174,6 +175,17 @@
diff = rnd2 - rnd1
self.assertTrue(np.linalg.norm(diff.eval()) > 0.1)
+ def testEagerSeed(self):
+ with context.eager_mode():
+ # Ensure a context has been created
+ random_ops.random_normal([])
+ # Set the same seed twice and check that the values match
+ context.set_global_seed(42)
+ rnd1 = random_ops.random_normal([])
+ context.set_global_seed(42)
+ rnd2 = random_ops.random_normal([])
+ self.assertAllEqual(rnd1, rnd2)
+
class RandomUniformTest(test.TestCase):
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 126e39a..00faf3f 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -131,6 +131,9 @@
self._init_set_name(name)
+ # Holds functions for creating regularizer ops.
+ self._regularizer_factories = []
+
# Determine variable scope.
scope = kwargs.get('_scope')
if scope:
@@ -291,6 +294,22 @@
inputs_hash = None
return self._per_input_updates.get(inputs_hash, [])
+ def _get_regularizer_factories(self):
+ try:
+ # Some subclasses of Layer do not use its constructor.
+ return self._regularizer_factories
+ except AttributeError:
+ self._regularizer_factories = []
+ return self._regularizer_factories
+
+ def _maybe_create_variable_regularizers(self):
+ """Creates added but uninstantiated regularizers."""
+ factories = self._get_regularizer_factories()
+ if factories:
+ for factory in factories:
+ factory()
+ factories[:] = []
+
@property
def losses(self):
"""Losses which are associated with this `Layer`.
@@ -302,6 +321,7 @@
Returns:
A list of tensors.
"""
+ self._maybe_create_variable_regularizers()
if context.in_eager_mode():
# _losses may only contain variable regularization losses when executing
# eagerly, and they have been saved as lambdas to be executed when
@@ -385,6 +405,7 @@
inputs_hash = layers_util.object_list_uid(inputs)
else:
inputs_hash = None
+ self._maybe_create_variable_regularizers()
return self._per_input_losses.get(inputs_hash, [])
def build(self, _):
@@ -479,17 +500,20 @@
instance is returned.
Raises:
- RuntimeError: If called in Eager mode with regularizers.
+ RuntimeError: If called in Eager mode with partitioned variable
+ regularization.
"""
- if context.in_graph_mode():
+
+ in_graph_mode = context.in_graph_mode()
+ if in_graph_mode:
existing_variables = set(tf_variables.global_variables())
if dtype is None:
dtype = self.dtype or dtypes.float32
self._set_scope(None)
+ reuse = self.built or self._reuse
with vs.variable_scope(
- self._scope, reuse=(self.built or self._reuse),
- auxiliary_name_scope=False) as scope:
+ self._scope, reuse=reuse, auxiliary_name_scope=False) as scope:
with ops.name_scope(self._name_scope_name(scope)):
variable = vs.get_variable(name,
shape=shape,
@@ -498,39 +522,56 @@
constraint=constraint,
trainable=trainable and self.trainable,
partitioner=partitioner)
- if context.in_graph_mode():
+
+ if in_graph_mode:
if (trainable and self.trainable
and variable not in tf_variables.trainable_variables()):
# A custom getter / variable scope overrode the trainable flag.
trainable = False
if variable in existing_variables:
+ # To match the behavior of tf.get_variable(), we only apply
+ # regularization if the variable is newly created.
return variable
- if regularizer:
- # To match the behavior of tf.get_variable(), we only
- # apply regularization if the variable is newly created.
- if isinstance(variable, tf_variables.PartitionedVariable):
- for v in variable:
- with ops.colocate_with(v.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(v)
- if regularization is not None:
- self.add_loss(regularization)
+
+ if regularizer:
+ def regularizer_factory():
+ if context.in_graph_mode():
+ with vs.variable_scope(scope, reuse=reuse,
+ auxiliary_name_scope=False):
+ with ops.name_scope(self._name_scope_name(scope)):
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ for v in variable:
+ with ops.colocate_with(v.op):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(v)
+ if regularization is not None:
+ self.add_loss(regularization)
+ else:
+ with ops.colocate_with(variable.op):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(variable)
+ if regularization is not None:
+ self.add_loss(regularization)
else:
- with ops.colocate_with(variable.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(variable)
- if regularization is not None:
- self.add_loss(regularization)
- elif regularizer:
- if isinstance(variable, tf_variables.PartitionedVariable):
- raise RuntimeError(
- 'Partitioned variable regularization is not yet supported when '
- 'executing eagerly. File a feature request is this is '
- 'important to you.')
- # Save a zero-argument lambda which runs the regularizer on the
- # variable, to be executed when `Layer.losses` is requested. This
- # makes losses responsive to variable updates when executing eagerly.
- self._losses.append(lambda: regularizer(variable))
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ raise RuntimeError(
+ 'Partitioned variable regularization is not yet '
+ 'supported when executing eagerly. File a feature request '
+ 'if this is important to you.')
+ # Save a zero-argument lambda which runs the regularizer on the
+ # variable, to be executed when `Layer.losses` is requested.
+ # This makes losses responsive to variable updates when
+ # executing eagerly.
+ self._losses.append(lambda: regularizer(variable))
+
+ if hasattr(self, '_defer_regularizers') and self._defer_regularizers:
+ # _defer_regularizers exists and is set to True if `build` was
+ # invoked in `__call__`: deferring regularizer construction
+ # prevents the regularizer from being created in an `init_scope`.
+ self._get_regularizer_factories().append(regularizer_factory)
+ else:
+ regularizer_factory()
+
if trainable:
self._trainable_weights.append(variable)
else:
@@ -629,7 +670,15 @@
except AttributeError:
pass
input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
- self.build(input_shapes)
+
+ # Signal to `add_variable` that regularizer construction should be
+ # deferred.
+ self._defer_regularizers = True
+ with ops.init_scope():
+ self.build(input_shapes)
+ # Create any regularizers added by `build`.
+ self._maybe_create_variable_regularizers()
+ self._defer_regularizers = False
try:
# Note: not all sub-classes of Layer call Layer.__init__ (especially
# the ones under tensorflow/python/keras). Hence we recompute this
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index c26b136..06ba214 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -531,6 +531,30 @@
self.assertEqual(layer2.my_var.name, 'name_3/my_var:0')
self.assertEqual(op2.name, 'name_3/my_op:0')
+ def testVariablesAreLiftedFromFunctionBuildingGraphs(self):
+ class MyLayer(base_layers.Layer):
+
+ def build(self, input_shape):
+ self.my_var = self.add_variable('my_var', (), dtypes.float32)
+ self.built = True
+
+ def call(self, inputs):
+ return inputs
+
+ outer_graph = ops.get_default_graph()
+ function_building_graph = ops.Graph()
+ function_building_graph._building_function = True
+ with outer_graph.as_default():
+ with function_building_graph.as_default():
+ layer = MyLayer()
+ # Create a variable by invoking build through __call__ and assert that
+ # it is both tracked and lifted into the outer graph.
+ inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
+ layer.apply(inputs)
+ self.assertEqual(len(layer.variables), 1)
+ self.assertEqual(len(layer.trainable_variables), 1)
+ self.assertEqual(layer.variables[0].graph, outer_graph)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 74b4056..88d1ce5 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1494,20 +1494,17 @@
zero = ""
else:
zero = 0
- # Checking for boolean dtype to prevent attempting to run fill on the GPU
- # which does not have a boolean kernel registered.
- if context.in_eager_mode() and dtype != dtypes.bool:
- return fill(shape, constant(zero, dtype=dtype), name=name)
- try:
- if isinstance(shape, ops.Tensor):
- # TODO(apassos) this is required to reproduce the behavior from before
- # Tensors were iterable. It's a crutch.
- raise TypeError
- shape = tensor_shape.as_shape(shape)
- output = constant(zero, shape=shape, dtype=dtype, name=name)
- except (TypeError, ValueError):
- shape = ops.convert_to_tensor(shape, dtype=dtypes.int32, name="shape")
- output = fill(shape, constant(zero, dtype=dtype), name=name)
+ if not isinstance(shape, ops.Tensor):
+ try:
+ # Go through tensor shapes to get int64-if-needed semantics
+ shape = constant_op._tensor_shape_tensor_conversion_function(
+ tensor_shape.TensorShape(shape))
+ except (TypeError, ValueError):
+ # Happens when shape is a list with tensor elements
+ shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
+ if not shape._shape_tuple():
+ shape = reshape(shape, [-1]) # Ensure it's a vector
+ output = fill(shape, constant(zero, dtype=dtype), name=name)
assert output.dtype.base_dtype == dtype
return output
@@ -1625,15 +1622,17 @@
dtype = dtypes.as_dtype(dtype).base_dtype
with ops.name_scope(name, "ones", [shape]) as name:
one = True if dtype == dtypes.bool else 1
- try:
- if isinstance(shape, ops.Tensor):
- raise TypeError(
- "preserving semantics from before tensors were iterable")
- shape = tensor_shape.as_shape(shape)
- output = constant(one, shape=shape, dtype=dtype, name=name)
- except (TypeError, ValueError):
- shape = ops.convert_to_tensor(shape, dtype=dtypes.int32, name="shape")
- output = fill(shape, constant(one, dtype=dtype), name=name)
+ if not isinstance(shape, ops.Tensor):
+ try:
+ # Go through tensor shapes to get int64-if-needed semantics
+ shape = constant_op._tensor_shape_tensor_conversion_function(
+ tensor_shape.TensorShape(shape))
+ except (TypeError, ValueError):
+ # Happens when shape is a list with tensor elements
+ shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
+ if not shape._shape_tuple():
+ shape = reshape(shape, [-1]) # Ensure it's a vector
+ output = fill(shape, constant(one, dtype=dtype), name=name)
assert output.dtype.base_dtype == dtype
return output
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index ab7a903..879c206 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -276,10 +276,6 @@
dtype=dtype,
constraint=constraint)
- # LINT.IfChange
- # _VariableFromResource inherits from ResourceVariable but
- # doesn't call the constructor, so changes here might need to be reflected
- # there.
# pylint: disable=unused-argument
def _init_from_args(self,
initial_value=None,
@@ -438,7 +434,8 @@
self._initializer_op = (
gen_resource_variable_ops.assign_variable_op(
self._handle,
- self._build_initializer_expr(initial_value),
+ self._try_guard_against_uninitialized_dependencies(
+ initial_value),
name=n))
with ops.name_scope("Read"), ops.colocate_with(self._handle):
# Manually assign reads to the handle's device to avoid log
@@ -522,7 +519,6 @@
self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
self._graph_element = self.value()
self._constraint = None
- # LINT.ThenChange(//tensorflow/python/eager/graph_callable.py)
def __nonzero__(self):
return self.__bool__()
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index 07796b2..e0cf1bf 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -127,8 +127,8 @@
Returns:
A function to encapsulate a set of variables which should be created once
- and reused. An enclosing scope will created, either where `make_template`
- is called, or wherever the result is called, depending on the value of
+ and reused. An enclosing scope will be created either when `make_template`
+ is called or when the result is called, depending on the value of
`create_scope_now_`. Regardless of the value, the first time the template
is called it will enter the scope with no reuse, and call `func_` to create
variables, which are guaranteed to be unique. All subsequent calls will
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index e0748d8..b258556 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -362,7 +362,8 @@
# using their initialized_value() method.
self._initializer_op = state_ops.assign(
self._variable,
- self._build_initializer_expr(self._initial_value),
+ self._try_guard_against_uninitialized_dependencies(
+ self._initial_value),
validate_shape=validate_shape).op
# TODO(vrv): Change this class to not take caching_device, but
@@ -781,88 +782,142 @@
setattr(Variable, operator, _run_op)
- def _build_initializer_expr(self, initial_value):
- """Build an expression suitable to initialize a variable.
+ def _try_guard_against_uninitialized_dependencies(self, initial_value):
+ """Attempt to guard against dependencies on uninitialized variables.
- Replace references to variables in initial_value with references to the
- variable initial values instead.
+ Replace references to variables in `initial_value` with references to the
+ variable's initialized values. The initialized values are essentially
+ conditional TensorFlow graphs that return a variable's value if it is
+ initialized or its `initial_value` if it hasn't been initialized. This
+ replacement is done on a best effort basis:
+
+ - If the `initial_value` graph contains cycles, we don't do any
+ replacements for that graph.
+ - If the variables that `initial_value` depends on are not present in the
+ `GLOBAL_VARIABLES` or `LOCAL_VARIABLES` we don't replace them.
+
+ In these cases, it is up to the caller to ensure that the `initial_value`
+ graph uses initialized variables or that they guard access to variables
+ using their `initialized_value` method.
Args:
- initial_value: original expression
+ initial_value: `Tensor`. The initial value.
Returns:
- A tensorflow expression suitable to initialize a variable.
+ A `Tensor` suitable to initialize a variable.
+ Raises:
+ TypeError: If `initial_value` is not a `Tensor`.
"""
- if isinstance(initial_value, Variable):
- return initial_value.initialized_value()
- elif isinstance(initial_value, ops.Tensor):
- new_op = self._build_initializer_expr(initial_value.op)
- if new_op != initial_value.op:
- if isinstance(new_op, ops.Tensor):
- return new_op
- else:
- return ops.Tensor(new_op, initial_value.value_index,
- initial_value.dtype)
- else:
- return initial_value
- elif isinstance(initial_value, ops.Operation):
- if initial_value.node_def.op in [
- "IsVariableInitialized", "VarIsInitializedOp", "ReadVariableOp"
- ]:
- return initial_value
- if initial_value.node_def.op in ["Variable", "VariableV2", "VarHandleOp"]:
- return self._find_initialized_value_for_variable(initial_value)
- modified = False
- new_inputs = []
- for tensor in initial_value.inputs:
- new_tensor = self._build_initializer_expr(tensor)
- new_inputs.append(new_tensor)
- if new_tensor != tensor:
- modified = True
+ if not isinstance(initial_value, ops.Tensor):
+ raise TypeError("initial_value needs to be a Tensor: %s" % initial_value)
- if modified:
- new_name = initial_value.node_def.name + "_" + self.name
- new_name = new_name.replace(":", "_")
- new_op = initial_value.node_def.op
- new_op = new_op.replace("RefSwitch", "Switch")
- new_value = self.graph.create_op(
- new_op,
- new_inputs,
- # pylint: disable=protected-access
- initial_value._output_types,
- # pylint: enable=protected-access
- name=new_name,
- attrs=initial_value.node_def.attr)
- return new_value
- else:
- return initial_value
- else:
+ # Don't modify initial_value if it contains any cyclic dependencies.
+ def has_cycle(op, path):
+ """Detect cycles in the dependencies of `initial_value`."""
+ if op.name in path:
+ return True
+ path.add(op.name)
+ for op_input in op.inputs:
+ if has_cycle(op_input.op, path):
+ return True
+ for op_control_input in op.control_inputs:
+ if has_cycle(op_control_input, path):
+ return True
+ path.remove(op.name)
+ return False
+ if has_cycle(initial_value.op, path=set()):
return initial_value
+ return self._safe_initial_value_from_tensor(initial_value, op_cache={})
+
+ def _safe_initial_value_from_tensor(self, tensor, op_cache):
+ """Replace dependencies on variables with their initialized values.
+
+ Args:
+ tensor: A `Tensor`. The tensor to replace.
+ op_cache: A dict mapping operation names to `Operation`s. Used to memoize
+ the results so as to avoid creating redundant operations.
+ Returns:
+ A `Tensor` compatible with `tensor`. Any inputs that lead to variable
+ values will be replaced with a corresponding graph that uses the
+ variable's initialized values. This is done on a best-effort basis. If no
+ modifications need to be made then `tensor` will be returned unchanged.
+ """
+ op = tensor.op
+ new_op = op_cache.get(op.name)
+ if new_op is None:
+ new_op = self._safe_initial_value_from_op(op, op_cache)
+ op_cache[op.name] = new_op
+ return new_op.outputs[tensor.value_index]
+
+ def _safe_initial_value_from_op(self, op, op_cache):
+ """Replace dependencies on variables with their initialized values.
+
+ Args:
+ op: An `Operation`. The operation to replace.
+ op_cache: A dict mapping operation names to `Operation`s. Used to memoize
+ the results so as to avoid creating redundant operations.
+ Returns:
+ An `Operation` compatible with `op`. Any inputs that lead to variable
+ values will be replaced with a corresponding graph that uses the
+ variable's initialized values. This is done on a best-effort basis. If no
+ modifications need to be made then `op` will be returned unchanged.
+ """
+ op_type = op.node_def.op
+ if op_type in ("IsVariableInitialized", "VarIsInitializedOp",
+ "ReadVariableOp"):
+ return op
+
+ # Attempt to find the initialized_value of any variable reference / handles.
+ # TODO(b/70206927): Fix handling of ResourceVariables.
+ if op_type in ("Variable", "VariableV2", "VarHandleOp"):
+ initialized_value = self._find_initialized_value_for_variable(op)
+ return op if initialized_value is None else initialized_value.op
+
+ # Recursively build initializer expressions for inputs.
+ modified = False
+ new_op_inputs = []
+ for op_input in op.inputs:
+ new_op_input = self._safe_initial_value_from_tensor(op_input, op_cache)
+ new_op_inputs.append(new_op_input)
+ modified = modified or (new_op_input != op_input)
+
+ # If at least one input was modified, replace the op.
+ if modified:
+ new_op_type = op_type
+ if new_op_type == "RefSwitch":
+ new_op_type = "Switch"
+ new_op_name = op.node_def.name + "_" + self.name
+ new_op_name = new_op_name.replace(":", "_")
+ return self.graph.create_op(
+ new_op_type, new_op_inputs,
+ op._output_types, # pylint: disable=protected-access
+ name=new_op_name, attrs=op.node_def.attr)
+
+ return op
+
def _find_initialized_value_for_variable(self, variable_op):
- """Find the initial value for a variable op.
+ """Find the initialized value for a variable op.
To do so, lookup the variable op in the variables collection.
Args:
- variable_op: a TensorFlow variable Operation
+ variable_op: A variable `Operation`.
Returns:
- The initial value for the variable.
+ A `Tensor` representing the initialized value for the variable or `None`
+ if the initialized value could not be found.
"""
try:
var_names = [variable_op.node_def.name, variable_op.node_def.name + ":0"]
- global_vars = self.graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- for var in global_vars:
- if var.name in var_names:
- return var.initialized_value()
- local_vars = self.graph.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
- for var in local_vars:
- if var.name == var_names:
- return var.initialized_value()
+ for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES,
+ ops.GraphKeys.LOCAL_VARIABLES):
+ for var in self.graph.get_collection(collection_name):
+ if var.name in var_names:
+ return var.initialized_value()
except AttributeError:
- # Return the variable itself when an incomplete user defined variable type
- # was put in the collection.
- return variable_op
- return variable_op
+ # Return None when an incomplete user-defined variable type was put in
+ # the collection.
+ return None
+ return None
# NOTE(mrry): This enables the Variable's overloaded "right" binary
# operators to run when the left operand is an ndarray, because it
diff --git a/tensorflow/python/profiler/internal/print_model_analysis_test.py b/tensorflow/python/profiler/internal/print_model_analysis_test.py
index 797c430..186c028 100644
--- a/tensorflow/python/profiler/internal/print_model_analysis_test.py
+++ b/tensorflow/python/profiler/internal/print_model_analysis_test.py
@@ -18,22 +18,13 @@
from __future__ import division
from __future__ import print_function
-from google.protobuf import text_format
-
-from tensorflow.core.profiler import tfprof_options_pb2
-from tensorflow.core.profiler import tfprof_output_pb2
-from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
-# pylint: disable=g-bad-import-order
-# XXX: this depends on pywrap_tensorflow and must come later
-from tensorflow.python import pywrap_tensorflow as print_mdl
# pylint: disable=bad-whitespace
# pylint: disable=bad-continuation
@@ -69,407 +60,6 @@
x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
return x
- def testPrintModelAnalysis(self):
- opts = tfprof_options_pb2.OptionsProto()
- opts.max_depth = TEST_OPTIONS['max_depth']
- opts.min_bytes = TEST_OPTIONS['min_bytes']
- opts.min_micros = TEST_OPTIONS['min_micros']
- opts.min_params = TEST_OPTIONS['min_params']
- opts.min_float_ops = TEST_OPTIONS['min_float_ops']
- opts.order_by = TEST_OPTIONS['order_by']
- opts.step = -1
- for p in TEST_OPTIONS['account_type_regexes']:
- opts.account_type_regexes.append(p)
- for p in TEST_OPTIONS['start_name_regexes']:
- opts.start_name_regexes.append(p)
- for p in TEST_OPTIONS['trim_name_regexes']:
- opts.trim_name_regexes.append(p)
- for p in TEST_OPTIONS['show_name_regexes']:
- opts.show_name_regexes.append(p)
- for p in TEST_OPTIONS['hide_name_regexes']:
- opts.hide_name_regexes.append(p)
- opts.account_displayed_op_only = TEST_OPTIONS['account_displayed_op_only']
- for p in TEST_OPTIONS['select']:
- opts.select.append(p)
- opts.output = TEST_OPTIONS['output']
-
- with session.Session() as sess, ops.device('/cpu:0'):
- _ = self._BuildSmallModel()
- tfprof_pb = tfprof_output_pb2.GraphNodeProto()
- tfprof_pb.ParseFromString(
- print_mdl.PrintModelAnalysis(
- sess.graph.as_graph_def(add_shapes=True).SerializeToString(),
- b'',
- b'',
- b'scope',
- opts.SerializeToString()))
-
- expected_pb = tfprof_output_pb2.GraphNodeProto()
- text_format.Merge(r"""name: "_TFProfRoot"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 648
- children {
- name: "Conv2D"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- input_shapes {
- key: 0
- value {
- dim {
- size: 2
- }
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- }
- }
- input_shapes {
- key: 1
- value {
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- dim {
- size: 6
- }
- }
- }
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- children {
- name: "DW"
- exec_micros: 0
- requested_bytes: 0
- parameters: 648
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 648
- children {
- name: "DW/Assign"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- input_shapes {
- key: 0
- value {
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- dim {
- size: 6
- }
- }
- }
- input_shapes {
- key: 1
- value {
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- dim {
- size: 6
- }
- }
- }
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- children {
- name: "DW/Initializer"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- children {
- name: "DW/Initializer/random_normal"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- children {
- name: "DW/Initializer/random_normal/RandomStandardNormal"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- input_shapes {
- key: 0
- value {
- dim {
- size: 4
- }
- }
- }
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- children {
- name: "DW/Initializer/random_normal/mean"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- children {
- name: "DW/Initializer/random_normal/mul"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- input_shapes {
- key: 0
- value {
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- dim {
- size: 6
- }
- }
- }
- input_shapes {
- key: 1
- value {
- dim {
- size: 1
- }
- }
- }
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- children {
- name: "DW/Initializer/random_normal/shape"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- children {
- name: "DW/Initializer/random_normal/stddev"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- float_ops: 0
- total_float_ops: 0
- input_shapes {
- key: 0
- value {
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- dim {
- size: 6
- }
- }
- }
- input_shapes {
- key: 1
- value {
- dim {
- size: 1
- }
- }
- }
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 6
- }
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 7
- }
- children {
- name: "DW/read"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- input_shapes {
- key: 0
- value {
- dim {
- size: 6
- }
- dim {
- size: 6
- }
- dim {
- size: 3
- }
- dim {
- size: 6
- }
- }
- }
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 10
- }
- children {
- name: "zeros"
- exec_micros: 0
- requested_bytes: 0
- total_exec_micros: 0
- total_requested_bytes: 0
- total_parameters: 0
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 1
- }
- float_ops: 0
- total_float_ops: 0
- accelerator_exec_micros: 0
- cpu_exec_micros: 0
- total_accelerator_exec_micros: 0
- total_cpu_exec_micros: 0
- run_count: 0
- total_run_count: 0
- total_definition_count: 13""", expected_pb)
- self.assertEqual(expected_pb, tfprof_pb)
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index a379bd5..14ad9e5 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -164,13 +164,6 @@
model_analyzer.profile(
sess.graph, run_meta, options=opts)
- with gfile.Open(outfile, 'r') as f:
- # pylint: disable=line-too-long
- self.assertEqual(
- 'node name | # parameters | # float_ops | assigned devices | op types | op count (run|defined) | input shapes\n_TFProfRoot (--/451 params, --/11.34k flops, _kTFScopeParent, --/8|--/36, )\n Conv2D (0/0 params, 5.83k/5.83k flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|Conv2D, 1/1|1/1, 0:2x6x6x3|1:3x3x3x6)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|Conv2D, 1/1|1/1, 0:2x3x3x6|1:2x2x6x12)\n DW (3x3x3x6, 162/162 params, 0/324 flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n DW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:3x3x3x6|1:3x3x3x6)\n DW/Initializer (0/0 params, 0/324 flops, _kTFScopeParent, 0/0|1/7, )\n DW/Initializer/random_normal (0/0 params, 162/324 flops, Add, 0/0|1/6, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:4)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/Initializer/random_normal/mul (0/0 params, 162/162 flops, Mul, 0/0|1/1, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|Identity, 1/1|1/1, 0:3x3x3x6)\n DW2 (2x2x6x12, 288/288 params, 0/576 flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n DW2/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:2x2x6x12|1:2x2x6x12)\n DW2/Initializer (0/0 params, 0/576 flops, _kTFScopeParent, 0/0|1/7, )\n DW2/Initializer/random_normal (0/0 params, 288/576 flops, Add, 0/0|1/6, 0:2x2x6x12|1:1)\n 
DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:4)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/Initializer/random_normal/mul (0/0 params, 288/288 flops, Mul, 0/0|1/1, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|Identity, 1/1|1/1, 0:2x2x6x12)\n ScalarW (1, 1/1 params, 0/2 flops, VariableV2|_trainable_variables, 0/0|1/10, )\n ScalarW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:1|1:1)\n ScalarW/Initializer (0/0 params, 0/2 flops, _kTFScopeParent, 0/0|1/7, )\n ScalarW/Initializer/random_normal (0/0 params, 1/2 flops, Add, 0/0|1/6, 0:1|1:1)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:0)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/Initializer/random_normal/mul (0/0 params, 1/1 flops, Mul, 0/0|1/1, 0:1|1:1)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/read (0/0 params, 0/0 flops, Identity, 0/0|1/1, 0:1)\n _retval_Conv2D_1_0_0 (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|_retval_Conv2D_1_0_0, 1/1|1/1, )\n init (0/0 params, 0/0 flops, NoOp, 0/0|1/1, 0:1|1:3x3x3x6|2:2x2x6x12)\n zeros (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|Const, 1/1|1/1, )\n',
- f.read())
- # pylint: enable=line-too-long
-
def testSimpleCodeView(self):
ops.reset_default_graph()
outfile = os.path.join(test.get_temp_dir(), 'dump')
@@ -376,7 +369,6 @@
self.assertLessEqual(len(tfprof_node.graph_nodes), last_occurrence)
last_occurrence = len(tfprof_node.graph_nodes)
- self.assertEqual(total_children, 15)
self.assertGreater(input_shapes, 0)
def testAdvisor(self):
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 82750e9..d97823c 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -20,6 +20,7 @@
%rename("%s") TFE_ContextListDevices;
%rename("%s") TFE_ContextAddFunction;
%rename("%s") TFE_ContextAddFunctionDef;
+%rename("%s") TFE_ContextClearCaches;
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;
%rename("%s") TFE_Py_RegisterExceptionClass;
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py
index 5879fd3..6670d93 100644
--- a/tensorflow/python/training/session_manager_test.py
+++ b/tensorflow/python/training/session_manager_test.py
@@ -26,6 +26,7 @@
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@@ -504,6 +505,7 @@
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="x")
+ # TODO(b/70206927): Use ResourceVariables once they are handled properly.
v_res = variables.Variable(1, name="v_res")
w_res = variables.Variable(
v_res,
@@ -556,6 +558,24 @@
self.assertEquals(1, sess.run(w_res))
self.assertEquals(3, sess.run(x_res))
+ def testPrepareSessionWithCyclicInitializer(self):
+ # Regression test. Previously Variable._build_initializer_expr would enter
+ # into an infinite recursion when the variable's initial_value involved
+ # cyclic dependencies.
+ with ops.Graph().as_default():
+ i = control_flow_ops.while_loop(lambda i: i < 1, lambda i: i + 1, [0])
+ v = variables.Variable(array_ops.identity(i), name="v")
+ with self.test_session():
+ self.assertEqual(False, variables.is_variable_initialized(v).eval())
+ sm = session_manager.SessionManager(
+ ready_op=variables.report_uninitialized_variables())
+ sess = sm.prepare_session("", init_op=v.initializer)
+ self.assertEqual(1, sess.run(v))
+ self.assertEqual(
+ True,
+ variables.is_variable_initialized(
+ sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
+
def testPrepareSessionDidNotInitLocalVariable(self):
with ops.Graph().as_default():
v = variables.Variable(1, name="v")
diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
index 80f2b59..6271fce 100755
--- a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
+++ b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
@@ -73,13 +73,16 @@
# Zip up the .dll, LICENSE and include files for the C library.
mkdir -p ${DIR}/include/tensorflow/c
+mkdir -p ${DIR}/include/tensorflow/c/eager
mkdir -p ${DIR}/lib
cp bazel-bin/tensorflow/libtensorflow.so ${DIR}/lib/tensorflow.dll
cp tensorflow/c/c_api.h ${DIR}/include/tensorflow/c
+cp tensorflow/c/eager/c_api.h ${DIR}/include/tensorflow/c/eager
cp bazel-genfiles/tensorflow/tools/lib_package/include/tensorflow/c/LICENSE ${DIR}/include/tensorflow/c
cd ${DIR}
zip -j libtensorflow-cpu-windows-$(uname -m).zip \
lib/tensorflow.dll \
+ include/tensorflow/c/eager/c_api.h \
include/tensorflow/c/c_api.h \
include/tensorflow/c/LICENSE
rm -rf lib include
diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD
index 845bad5..dbc8159 100644
--- a/tensorflow/tools/lib_package/BUILD
+++ b/tensorflow/tools/lib_package/BUILD
@@ -55,7 +55,10 @@
pkg_tar(
name = "cheaders",
- files = ["//tensorflow/c:headers"],
+ files = [
+ "//tensorflow/c:headers",
+ "//tensorflow/c/eager:headers",
+ ],
package_dir = "include/tensorflow/c",
# Mark as "manual" till
# https://github.com/bazelbuild/bazel/issues/2352
diff --git a/tools/bazel.rc b/tools/bazel.rc
index 04c24d7..44bfe9f 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -1,3 +1,28 @@
+# Android configs
+build:android --crosstool_top=//external:android/crosstool
+build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
+build:android_arm --config=android
+build:android_arm --cpu=armeabi-v7a
+build:android_arm64 --config=android
+build:android_arm64 --cpu=arm64-v8a
+
+# Config to use a mostly-static build and disable modular op registration
+# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python).
+# By default, TensorFlow will build with a dependence on
+# //tensorflow:libtensorflow_framework.so.
+build:monolithic --define framework_shared_object=false
+
+# For projects which use TensorFlow as part of a Bazel build process, putting
+# nothing in a bazelrc will default to a monolithic build. The following line
+# opts in to modular op registration support by default.
+build --define framework_shared_object=true
+
+# Please note that MKL on macOS or Windows is still not supported.
+# If you would like to use a local MKL instead of downloading it, please set
+# the environment variable "TF_MKL_ROOT" before every build.
+build:mkl --define=using_mkl=true
+build:mkl -c opt
+
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true