Merge commit for internal changes
diff --git a/configure.py b/configure.py
index 05e3e80..cf16ef4 100644
--- a/configure.py
+++ b/configure.py
@@ -1122,12 +1122,16 @@
   write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH',
                               computecpp_toolkit_path)
 
+
 def set_trisycl_include_dir(environ_cp):
   """Set TRISYCL_INCLUDE_DIR."""
-  ask_trisycl_include_dir = (
-      'Please specify the location of the triSYCL include directory. (Use '
-      '--config=sycl_trisycl when building with Bazel) '
-      '[Default is %s]: ') % _DEFAULT_TRISYCL_INCLUDE_DIR
+
+  ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
+                             'include directory. (Use --config=sycl_trisycl '
+                             'when building with Bazel) '
+                             '[Default is %s]: '
+                            ) % (_DEFAULT_TRISYCL_INCLUDE_DIR)
+
   while True:
     trisycl_include_dir = get_from_env_or_user_or_default(
         environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir,
@@ -1201,46 +1205,10 @@
     raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home)
 
 
-def set_mkl():
-  write_to_bazelrc('build:mkl --define using_mkl=true')
-  write_to_bazelrc('build:mkl -c opt')
-  print(
-      'Add "--config=mkl" to your bazel command to build with MKL '
-      'support.\nPlease note that MKL on MacOS or windows is still not '
-      'supported.\nIf you would like to use a local MKL instead of '
-      'downloading, please set the environment variable \"TF_MKL_ROOT\" every '
-      'time before build.\n')
-
-
-def set_monolithic():
-  # Add --config=monolithic to your bazel command to use a mostly-static
-  # build and disable modular op registration support (this will revert to
-  # loading TensorFlow with RTLD_GLOBAL in Python). By default (without
-  # --config=monolithic), TensorFlow will build with a dependence on
-  # //tensorflow:libtensorflow_framework.so.
-  write_to_bazelrc('build:monolithic --define framework_shared_object=false')
-  # For projects which use TensorFlow as part of a Bazel build process, putting
-  # nothing in a bazelrc will default to a monolithic build. The following line
-  # opts in to modular op registration support by default:
-  write_to_bazelrc('build --define framework_shared_object=true')
-
-
-def create_android_bazelrc_configs():
-  # Flags for --config=android
-  write_to_bazelrc('build:android --crosstool_top=//external:android/crosstool')
-  write_to_bazelrc(
-      'build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain')
-  # Flags for --config=android_arm
-  write_to_bazelrc('build:android_arm --config=android')
-  write_to_bazelrc('build:android_arm --cpu=armeabi-v7a')
-  # Flags for --config=android_arm64
-  write_to_bazelrc('build:android_arm64 --config=android')
-  write_to_bazelrc('build:android_arm64 --cpu=arm64-v8a')
-
-
 def set_grpc_build_flags():
   write_to_bazelrc('build --define grpc_no_ares=true')
 
+
 def set_windows_build_flags():
   if is_windows():
     # The non-monolithic build is not supported yet
@@ -1251,6 +1219,11 @@
     write_to_bazelrc('build --verbose_failures')
 
 
+def config_info_line(name, help_text):
+  """Helper function to print formatted help text for Bazel config options."""
+  print('\t--config=%-12s\t# %s' % (name, help_text))
+
+
 def main():
   # Make a copy of os.environ to be clear when functions and getting and setting
   # environment variables.
@@ -1336,10 +1309,7 @@
 
   set_grpc_build_flags()
   set_cc_opt_flags(environ_cp)
-  set_mkl()
-  set_monolithic()
   set_windows_build_flags()
-  create_android_bazelrc_configs()
 
   if workspace_has_any_android_rule():
     print('The WORKSPACE file has at least one of ["android_sdk_repository", '
@@ -1357,6 +1327,11 @@
       create_android_ndk_rule(environ_cp)
       create_android_sdk_rule(environ_cp)
 
+  print('Preconfigured Bazel build configs. You can use any of the below by '
+        'adding "--config=<>" to your build command. See tools/bazel.rc for '
+        'more details.')
+  config_info_line('mkl', 'Build with MKL support.')
+  config_info_line('monolithic', 'Config for mostly static monolithic build.')
 
 if __name__ == '__main__':
   main()
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 53df884..74190cb 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -117,3 +117,9 @@
         "//tensorflow/core:lib",
     ],
 )
+
+filegroup(
+    name = "headers",
+    srcs = ["c_api.h"],
+    visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 589afb9..5a5e5fe 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -98,7 +98,10 @@
 
 void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
   status->status = tensorflow::Status::OK();
-  tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
+  {
+    tensorflow::mutex_lock ml(ctx->cache_mu);
+    tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
+  }
   TF_Graph* graph = ctx->session->graph;
   TF_DeleteSession(ctx->session, status);
   TF_DeleteGraph(graph);
@@ -110,6 +113,11 @@
   return TF_SessionListDevices(ctx->session, status);
 }
 
+void TFE_ContextClearCaches(TFE_Context* ctx) {
+  tensorflow::mutex_lock ml(ctx->cache_mu);
+  tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
+}
+
 TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
   tensorflow::Tensor tensor;
   status->status = tensorflow::TF_TensorToTensor(t, &tensor);
@@ -489,8 +497,11 @@
   std::vector<tensorflow::Tensor> outputs(1);
   const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
   tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name());
-  tensorflow::KernelAndDevice* kernel =
-      tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
+  tensorflow::KernelAndDevice* kernel;
+  {
+    tensorflow::tf_shared_lock l(ctx->cache_mu);
+    kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
+  }
   if (kernel == nullptr) {
     const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
     kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
@@ -506,6 +517,7 @@
       delete kernel;
       return;
     }
+    tensorflow::mutex_lock ml(ctx->cache_mu);
     tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
   }
   std::vector<TFE_TensorHandle*> copied_tensors;
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 7caab43..9b0fd03 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -17,6 +17,8 @@
 #define TENSORFLOW_C_EAGER_C_API_H_
 
 // C API extensions to experiment with eager execution of kernels.
+// WARNING: Unlike tensorflow/c/c_api.h, the API here is not guaranteed to be
+// stable and can change without notice.
 
 #include "tensorflow/c/c_api.h"
 
@@ -87,6 +89,10 @@
 TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
                                                             TF_Status* status);
 
+// Clears the internal caches in the TFE context. Useful when reseeding random
+// ops.
+TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx);
+
 // A handle to a tensor on a device.
 //
 // Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 11b7a51..55a04d4 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -58,9 +58,10 @@
   // session->devices[i].
   std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
 
+  tensorflow::mutex cache_mu;
   std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
                      tensorflow::Fprint128Hasher>
-      kernel_cache;
+      kernel_cache GUARDED_BY(cache_mu);
 
   tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
     return pflr->GetFLR(d->name());
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 78777f3..f7c6cd2 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -248,7 +248,9 @@
     tags = ["optonly"],
     deps = [
         ":xla_test",
+        "//tensorflow/contrib/signal:signal_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python:extra_py_tests_deps",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:spectral_ops",
diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py
index bdc38be..afb5fa4 100644
--- a/tensorflow/compiler/tests/fft_test.py
+++ b/tensorflow/compiler/tests/fft_test.py
@@ -21,8 +21,10 @@
 import itertools
 
 import numpy as np
+import scipy.signal as sps
 
 from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.contrib.signal.python.ops import spectral_ops as signal
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import spectral_ops
@@ -76,6 +78,29 @@
         value = sess.run(out, {ph: data})
         self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL)
 
+  def testContribSignalSTFT(self):
+    ws = 512
+    hs = 128
+    dims = (ws * 20,)
+    shape = BATCH_DIMS + dims
+    data = np.arange(np.prod(shape)) / np.prod(dims)
+    np.random.seed(123)
+    np.random.shuffle(data)
+    data = np.reshape(data.astype(np.float32), shape)
+    window = sps.get_window("hann", ws)
+    expected = sps.stft(
+        data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2]
+    expected = np.swapaxes(expected, -1, -2)
+    expected *= window.sum()  # scipy divides by window sum
+    with self.test_session() as sess:
+      with self.test_scope():
+        ph = array_ops.placeholder(
+            dtypes.as_dtype(data.dtype), shape=data.shape)
+        out = signal.stft(ph, ws, hs)
+
+      value = sess.run(out, {ph: data})
+      self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL)
+
   def testFFT(self):
     self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft,
                           spectral_ops.fft)
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index dd67a1d..f76e214 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -37,6 +37,8 @@
 
 namespace {
 
+using xla::StatusOr;
+
 const char* const kArgOp = "_Arg";
 const char* const kRetValOp = "_Retval";
 
@@ -76,6 +78,20 @@
   std::unordered_set<Node*> nodes;
 };
 
+// Comparison function used for sorting nodes consistently, such that:
+// a) resource variables are last, and
+// b) nodes are sorted lexicographically by name (for deterministic output).
+struct NodeCmp {
+  bool operator()(const Node* lhs, const Node* rhs) const {
+    bool lhs_is_resource =
+        lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false;
+    bool rhs_is_resource =
+        rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false;
+    return std::tie(lhs_is_resource, lhs->name()) <
+           std::tie(rhs_is_resource, rhs->name());
+  }
+};
+
 // Returns a textual representation of the names of the nodes in the input.
 template <typename T>
 string NodesToString(const T& nodes) {
@@ -141,7 +157,7 @@
   return Status::OK();
 }
 
-xla::StatusOr<Node*> AddNode(const NodeDef& node_def, Graph* graph) {
+StatusOr<Node*> AddNode(const NodeDef& node_def, Graph* graph) {
   Status status;
   Node* inserted_node = graph->AddNode(node_def, &status);
   if (!status.ok()) {
@@ -150,7 +166,7 @@
   return inserted_node;
 }
 
-xla::StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
+StatusOr<Node*> BuildArgNode(Graph* graph, DataType type, int index) {
   NodeDef arg_def;
   NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp);
   builder.Attr("T", type);
@@ -159,7 +175,7 @@
   return AddNode(arg_def, graph);
 }
 
-xla::StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
+StatusOr<Node*> BuildRetvalNode(Graph* graph, DataType type, int index) {
   NodeDef ret_def;
   ret_def.set_op(kRetValOp);
   ret_def.set_name(strings::StrCat(kRetValOp, index));
@@ -310,16 +326,9 @@
   }
   frame->args = std::move(args);
 
-  // Order the arguments so that:
-  // a) resource variables are last, and
-  // b) sort lexicographically by name (for deterministic output).
-  std::sort(frame->args.begin(), frame->args.end(),
-            [](const Arg& a, const Arg& b) {
-              bool a_is_resource = (a.enter->input_type(0) == DT_RESOURCE);
-              bool b_is_resource = (b.enter->input_type(0) == DT_RESOURCE);
-              return std::tie(a_is_resource, a.enter->name()) <
-                     std::tie(b_is_resource, b.enter->name());
-            });
+  std::sort(
+      frame->args.begin(), frame->args.end(),
+      [](const Arg& a, const Arg& b) { return NodeCmp()(a.enter, b.enter); });
 
   if (frame->loop_cond == nullptr) {
     return errors::InvalidArgument("Loop ", frame->name,
@@ -542,18 +551,6 @@
   // Returns a textual representation of the Branch b.
   static string Branch_Name(FunctionalizeCond::Branch b);
 
-  // Comparison function used for sorting nodes consistently.
-  struct CondCmp {
-    bool operator()(const Node* lhs, const Node* rhs) const {
-      bool lhs_is_resource =
-          lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false;
-      bool rhs_is_resource =
-          rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false;
-      return std::tie(lhs_is_resource, lhs->name()) <
-             std::tie(rhs_is_resource, rhs->name());
-    }
-  };
-
   // Functionalize all the switch-merge nodes of a loop-free graph into XlaIf
   // nodes. That is, attempt to transform every remaining switch and merge nodes
   // in the graph into XlaIf nodes.
@@ -571,6 +568,13 @@
     int count;
   };
 
+  struct PredicateSwitches {
+    explicit PredicateSwitches(Node* predicate) : predicate(predicate) {}
+
+    Node* predicate;
+    std::vector<Node*> switches;
+  };
+
   FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library)
       : library_(library), graph_(graph) {}
 
@@ -579,18 +583,24 @@
   // extract into XlaIf nodes.
   Status FunctionalizeInternal();
 
-  // Converts a Merge node to a XlaIf. This encapsulates the process of
-  // extracting the bodies needed for the then and else branch, creates a XlaIf
-  // node, removing the nodes of the branches from the graph and replacing the
-  // merge node with a XlaIf.
-  Status ConvertCorrespondingMergeToXlaIf(
-      const std::vector<Node*>& switch_nodes,
-      const std::vector<Node*>& merge_nodes, Node* predicate);
+  // Determines the branch_map (mapping from node to branch of cond) and
+  // frontier (the nodes where the cond ends).
+  StatusOr<std::pair<std::unordered_map<Node*, ForwardFlowNode>,
+                     std::unordered_set<Node*>>>
+  DetermineBranchMapAndFrontier(const std::vector<Node*>& switches);
+
+  // Returns XlaIf node created from subgraph of merge and switch nodes. This
+  // encapsulates the process of extracting the bodies needed for the then and
+  // else branch, creating a XlaIf node, removing the nodes of the branches from
+  // the graph and replacing the merge node with a XlaIf.
+  StatusOr<Node*> ConvertToXlaIf(const std::vector<Node*>& switch_nodes,
+                                 const std::vector<Node*>& merge_nodes,
+                                 Node* predicate);
 
   // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with.
-  xla::StatusOr<Node*> BuildAndAddXlaIfOp(
-      const std::vector<Node*>& switch_nodes,
-      const std::vector<Node*>& merge_nodes, Node* predicate);
+  StatusOr<Node*> BuildAndAddXlaIfOp(const std::vector<Node*>& switch_nodes,
+                                     const std::vector<Node*>& merge_nodes,
+                                     Node* predicate);
 
   // Extracts a function body corresponding to the given input edge of the merge
   // node.
@@ -605,18 +615,26 @@
   // Adds all output edges from the `if_node`.
   Status AddOutputEdges(const std::vector<Node*>& outputs, Node* if_node);
 
-  // Returns the switches of graph_ in postorder. Dead switch nodes are skipped
-  // and removed from the graph.
-  std::vector<Node*> DetermineSwitchOrder();
+  // Returns the switches of graph_ (along with grouping predicates) in
+  // postorder. Dead switch nodes are skipped and removed from the graph.
+  std::vector<PredicateSwitches> DeterminePredicateSwitchOrder();
 
   // Update the state for destination based on the state of source and the node
   // being updated.
   Status Join(const ForwardFlowNode& src_state, const Node* dst,
               ForwardFlowNode* dst_state);
 
-  // Validates that the branch_map and frontier of nodes for the conditional
+  // Ensures that all nodes in the branch_map are dominated by the switch
+  // nodes. Returns nodes that are not dominated by the switches but are a
+  // control dependency of a node in the cond, and removes such control
+  // dependencies.
+  StatusOr<std::vector<Node*>> EnsureDominanceAndReturnNonDominatedControlNodes(
+      const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
+      const std::vector<Node*>& switches);
+
+  // Validates that the frontier of nodes for the conditional
   // section are as expected.
-  Status ValidBranchMapAndFrontier(
+  Status ValidateFrontier(
       const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
       const std::unordered_set<Node*>& frontier);
 
@@ -645,23 +663,11 @@
   return branch_name[b];
 }
 
-Status FunctionalizeCond::ValidBranchMapAndFrontier(
+Status FunctionalizeCond::ValidateFrontier(
     const std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>&
         branch_map,
     const std::unordered_set<Node*>& frontier) {
   std::unordered_set<const Node*> pending[kNumBranchTypes];
-  for (const auto& kv : branch_map) {
-    if (kv.second.count != kv.first->in_edges().size()) {
-      return errors::FailedPrecondition("Value ", kv.first->DebugString(),
-                                        " not dominated by switch nodes.");
-    }
-    if (VLOG_IS_ON(1)) {
-      // Append attribute to the graph if running with logging to make the
-      // changes clearer in the visualization.
-      kv.first->AddAttr("_XlaFunctionalizeBranch",
-                        Branch_Name(kv.second.branch));
-    }
-  }
   for (Node* n : frontier) {
     pending[branch_map.at(n).branch].insert(n);
   }
@@ -681,6 +687,10 @@
           ") in both Else and Then branch should be in Both.");
     }
   }
+  if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
+      pending[kElseBranch].empty()) {
+    return errors::Internal("Unexpected empty frontier for switch nodes");
+  }
   return Status::OK();
 }
 
@@ -707,7 +717,8 @@
   return Status::OK();
 }
 
-std::vector<Node*> FunctionalizeCond::DetermineSwitchOrder() {
+std::vector<FunctionalizeCond::PredicateSwitches>
+FunctionalizeCond::DeterminePredicateSwitchOrder() {
   std::vector<Node*> dead_switches;
   std::vector<Node*> switch_order;
   DFS(*graph_, nullptr, [this, &dead_switches, &switch_order](Node* n) {
@@ -725,25 +736,12 @@
     graph_->RemoveNode(n);
   }
 
-  return switch_order;
-}
-
-Status FunctionalizeCond::FunctionalizeInternal() {
-  std::vector<Node*> switch_order = DetermineSwitchOrder();
-  // If there are no switch nodes, then terminate.
+  std::vector<PredicateSwitches> predicate_switch_order;
   if (switch_order.empty()) {
-    return Status::OK();
+    return predicate_switch_order;
   }
 
-  struct PredicateSwitches {
-    explicit PredicateSwitches(Node* predicate) : predicate(predicate) {}
-
-    Node* predicate;
-    std::vector<Node*> switches;
-  };
-
   // Merge Switch nodes with common predicate.
-  std::vector<PredicateSwitches> predicate_switch_order;
   std::unordered_map<Node*, int> predicate_index;
   // The nodes in switch_order are in reverse topological order, but the
   // clustered switches need not be (i.e., when considered as a cluster one
@@ -759,71 +757,145 @@
     }
     predicate_switch_order[predicate_index[pred]].switches.push_back(*it);
   }
+  return predicate_switch_order;
+}
+
+StatusOr<std::vector<Node*>>
+FunctionalizeCond::EnsureDominanceAndReturnNonDominatedControlNodes(
+    const std::unordered_map<Node*, ForwardFlowNode>& branch_map,
+    const std::vector<Node*>& switches) {
+  std::vector<Node*> old_control_nodes;
+  for (const auto& kv : branch_map) {
+    if (kv.second.count != kv.first->in_edges().size()) {
+      std::vector<const Edge*> delete_edges;
+      for (const Edge* in : kv.first->in_edges()) {
+        auto it = branch_map.find(in->src());
+        if (it == branch_map.end()) {
+          if (in->IsControlEdge()) {
+            old_control_nodes.push_back(in->src());
+            delete_edges.push_back(in);
+          } else {
+            if (IsSwitch(in->src())) {
+              if (std::find(switches.begin(), switches.end(), in->src()) ==
+                  switches.end()) {
+                return errors::Internal(
+                    "Unexpected switch node found during flow forward.");
+              }
+              continue;
+            }
+            return errors::InvalidArgument(
+                "Value ", kv.first->name(), "'s input, ", in->src()->name(),
+                ", is not dominated by switch nodes ", NodesToString(switches));
+          }
+        }
+      }
+      // Remove control edges from nodes that are not dominated by the switch
+      // nodes. New control dependencies will be added between these nodes and
+      // the XlaIf node inserted.
+      for (const Edge* e : delete_edges) {
+        graph_->RemoveEdge(e);
+      }
+    }
+  }
+  return old_control_nodes;
+}
+
+StatusOr<
+    std::pair<std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>,
+              std::unordered_set<Node*>>>
+FunctionalizeCond::DetermineBranchMapAndFrontier(
+    const std::vector<Node*>& switches) {
+  std::unordered_map<Node*, ForwardFlowNode> branch_map;
+  std::unordered_set<Node*> frontier;
+  std::vector<Node*> stack = switches;
+  std::vector<bool> visited(graph_->num_node_ids(), false);
+  while (!stack.empty()) {
+    Node* n = stack.back();
+    stack.pop_back();
+
+    if (visited[n->id()]) {
+      continue;
+    }
+    visited[n->id()] = true;
+
+    // Propagate branch state along each edge of a switch node.
+    bool sink_only = true;
+    for (const Edge* e : n->out_edges()) {
+      Node* out = e->dst();
+      if (!out->IsOp()) {
+        continue;
+      }
+      sink_only = false;
+      // Propagate branch information.
+      ForwardFlowNode& ffn = branch_map[out];
+      if (IsSwitch(n)) {
+        int index = e->IsControlEdge() ? Branch::kNeither : e->src_output();
+        TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn));
+      } else {
+        TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn));
+      }
+      if (IsMerge(out)) {
+        if (out->in_edges().size() == ffn.count) {
+          frontier.insert(out);
+        }
+      } else if (!visited[out->id()]) {
+        stack.push_back(out);
+      }
+    }
+    if (sink_only) {
+      if (!IsIdentity(n)) {
+        VLOG(1) << "Feeding into sink: " << n->DebugString();
+      }
+    }
+  }
+
+  if (VLOG_IS_ON(2)) {
+    for (const auto& kv : branch_map) {
+      // Append attribute to the graph if running with logging to make the
+      // changes clearer in the visualization.
+      kv.first->AddAttr("_XlaFunctionalizeBranch",
+                        Branch_Name(kv.second.branch));
+    }
+  }
+  return std::make_pair(std::move(branch_map), std::move(frontier));
+}
+
+Status FunctionalizeCond::FunctionalizeInternal() {
+  std::vector<PredicateSwitches> predicate_switch_order =
+      DeterminePredicateSwitchOrder();
 
   // Iterate from innermost set of clustered switches to outermost, replacing
   // matching switch->merge subgraphs with single XlaIf nodes.
   for (auto it = predicate_switch_order.rbegin();
        it != predicate_switch_order.rend(); ++it) {
     auto& ps = *it;
-    VLOG(3) << "Flow down from: " << ps.predicate->name() << " -> "
-            << NodesToString(ps.switches);
+    VLOG(3) << "Flow down from: " << NodesToString(ps.switches) << " ("
+            << ps.predicate->name() << ")";
 
     std::unordered_map<Node*, ForwardFlowNode> branch_map;
     std::unordered_set<Node*> frontier;
+    TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier),
+                        DetermineBranchMapAndFrontier(ps.switches));
 
-    std::vector<Node*> stack = ps.switches;
-    std::vector<bool> visited(graph_->num_node_ids(), false);
-    while (!stack.empty()) {
-      Node* n = stack.back();
-      stack.pop_back();
-
-      if (visited[n->id()]) {
-        continue;
-      }
-      visited[n->id()] = true;
-
-      // Propagate branch state along each edge of a switch node.
-      bool sink_only = true;
-      for (const Edge* e : n->out_edges()) {
-        Node* out = e->dst();
-        if (!out->IsOp()) {
-          continue;
-        }
-        sink_only = false;
-        // Propagate branch information.
-        ForwardFlowNode& ffn = branch_map[out];
-        if (IsSwitch(n)) {
-          int index = e->IsControlEdge() ? Branch::kNeither : e->src_output();
-          TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn));
-        } else {
-          TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn));
-        }
-        if (IsMerge(out)) {
-          if (out->in_edges().size() == ffn.count) {
-            frontier.insert(out);
-          }
-        } else if (!visited[out->id()] && ffn.count == out->in_edges().size()) {
-          // If all predecessors are dominated by the switch nodes, then add
-          // the output to the stack.
-          stack.push_back(out);
-        }
-      }
-      if (sink_only) {
-        if (!IsIdentity(n)) {
-          VLOG(1) << "Feeding into sink: " << n->DebugString();
-        }
-      }
-    }
-
-    TF_RETURN_IF_ERROR(ValidBranchMapAndFrontier(branch_map, frontier));
     VLOG(2) << "FunctionalizeControlFlow (before XlaIf conversion): "
             << dump_graph::DumpGraphToFile("functionalize_bc", *graph_);
+    TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier));
+
     std::vector<Node*> switch_nodes(ps.switches);
-    std::sort(switch_nodes.begin(), switch_nodes.end(), CondCmp());
+    std::sort(switch_nodes.begin(), switch_nodes.end(), NodeCmp());
     std::vector<Node*> merge_nodes(frontier.begin(), frontier.end());
-    std::sort(merge_nodes.begin(), merge_nodes.end(), CondCmp());
-    TF_RETURN_IF_ERROR(ConvertCorrespondingMergeToXlaIf(
-        switch_nodes, merge_nodes, ps.predicate));
+    std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp());
+    TF_ASSIGN_OR_RETURN(std::vector<Node*> old_control_nodes,
+                        EnsureDominanceAndReturnNonDominatedControlNodes(
+                            branch_map, ps.switches));
+
+    TF_ASSIGN_OR_RETURN(
+        Node * if_node,
+        ConvertToXlaIf(switch_nodes, merge_nodes, ps.predicate));
+    for (Node* old : old_control_nodes) {
+      graph_->AddControlEdge(old, if_node);
+    }
+
     for (auto& del_kv : branch_map) {
       graph_->RemoveNode(del_kv.first);
     }
@@ -836,7 +908,7 @@
   return Status::OK();
 }
 
-xla::StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
+StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
     const std::vector<Node*>& switch_nodes,
     const std::vector<Node*>& merge_nodes, Node* predicate) {
   VLOG(2) << "Build if op for " << NodesToString(merge_nodes) << " with input "
@@ -917,7 +989,7 @@
   }
 
   std::vector<Node*> stack;
-  stack.reserve(switch_nodes.size());
+  stack.reserve(merge_nodes.size());
   for (int j = 0; j < merge_nodes.size(); ++j) {
     Node* node = merge_nodes[j];
     TF_ASSIGN_OR_RETURN(node_map.at(node->id()),
@@ -988,10 +1060,10 @@
   return Status::OK();
 }
 
-Status FunctionalizeCond::ConvertCorrespondingMergeToXlaIf(
+StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
     const std::vector<Node*>& switch_nodes,
     const std::vector<Node*>& merge_nodes, Node* predicate) {
-  VLOG(1) << "ConvertMergeToXlaIf for " << NodesToString(switch_nodes) << " -> "
+  VLOG(1) << "ConvertToXlaIf for " << NodesToString(switch_nodes) << " -> "
           << NodesToString(merge_nodes);
 
   // Extract bodies and builds a If operator.
@@ -1000,7 +1072,7 @@
   TF_RETURN_IF_ERROR(AddInputEdges(switch_nodes, predicate, if_node));
   TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node));
 
-  return Status::OK();
+  return if_node;
 }
 
 Status FunctionalizeCond::Functionalize(Graph* graph,
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index f373684..523169f 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -318,4 +318,8 @@
   return std::move(literal);
 }
 
+StatusOr<int> LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) {
+  return local_service_->ReplicaNumberToDeviceOrdinal(replica_number);
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 3ca0d2e..19fd14f 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -176,6 +176,13 @@
   StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocal(
       const Shape& shape, int device_ordinal);
 
+  // Returns the device ordinal that corresponds to the given replica number.
+  //
+  // This returns an error if there is not a one-to-one correspondence of
+  // replicas to device ordinals, but is useful as a short term mechanism for
+  // the "easy" case where a single replica is a single device.
+  StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);
+
   // Returns the platform that the underlying service targets.
   perftools::gputools::Platform* platform() const;
 
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index 03e403b..2b87cb0 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -21,9 +21,57 @@
 
 namespace swig {
 
-void TransferToInfeedLocal(const Literal& literal) {
-  LocalClient* client = ClientLibrary::LocalClientOrDie();
-  TF_CHECK_OK(client->TransferToInfeedLocal(literal, /*device_ordinal=*/0));
+// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of
+// device handles instead of needing to set the number of replicas at XLA
+// service initialization time.
+tensorflow::mutex g_replica_count_mutex(tensorflow::LINKER_INITIALIZED);
+int g_replica_count = 1;
+bool g_local_client_created = false;
+
+Status InitializeReplicaCount(int replica_count) {
+  if (replica_count < 1) {
+    return InvalidArgument("Replica count must be >= 1; got %d.",
+                           replica_count);
+  }
+  tensorflow::mutex_lock lock(g_replica_count_mutex);
+  if (g_local_client_created) {
+    return FailedPrecondition(
+        "Attempted to set the replica count to %d, but a local XLA service was "
+        "previously created with a replica count of %d.",
+        replica_count, g_replica_count);
+  }
+  g_replica_count = replica_count;
+  return Status::OK();
+}
+
+int GetReplicaCount() {
+  tensorflow::mutex_lock lock(g_replica_count_mutex);
+  return g_replica_count;
+}
+
+LocalClient* GetOrCreateLocalClient() {
+  LocalClientOptions options;
+  {
+    tensorflow::mutex_lock lock(g_replica_count_mutex);
+    options.set_number_of_replicas(g_replica_count);
+    g_local_client_created = true;
+  }
+  return ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie();
+}
+
+Status TransferToInfeedLocal(const Literal& literal) {
+  VLOG(1) << "Infeeding literal without replica number.";
+  LocalClient* client = GetOrCreateLocalClient();
+  return client->TransferToInfeedLocal(literal, /*device_ordinal=*/0);
+}
+
+Status TransferToInfeedLocalReplica(const Literal& literal,
+                                    int replica_number) {
+  VLOG(1) << "Infeeding literal to replica number: " << replica_number;
+  LocalClient* client = GetOrCreateLocalClient();
+  TF_ASSIGN_OR_RETURN(int device_ordinal,
+                      client->ReplicaNumberToDeviceOrdinal(replica_number));
+  return client->TransferToInfeedLocal(literal, device_ordinal);
 }
 
 LocalShapedBuffer::LocalShapedBuffer(
@@ -37,7 +85,7 @@
 
 /* static */
 LocalShapedBuffer* LocalShapedBuffer::FromLiteral(const Literal& argument) {
-  LocalClient* client = ClientLibrary::LocalClientOrDie();
+  LocalClient* client = GetOrCreateLocalClient();
   std::unique_ptr<ScopedShapedBuffer> buf =
       client
           ->LiteralToShapedBuffer(argument,
@@ -48,7 +96,7 @@
 }
 
 std::unique_ptr<Literal> LocalShapedBuffer::ToLiteral() const {
-  LocalClient* client = ClientLibrary::LocalClientOrDie();
+  LocalClient* client = GetOrCreateLocalClient();
   return client->ShapedBufferToLiteral(*shaped_buffer()).ConsumeValueOrDie();
 }
 
@@ -56,43 +104,95 @@
     std::unique_ptr<LocalExecutable> executable)
     : executable_(std::move(executable)) {}
 
-std::unique_ptr<Literal> CompiledLocalComputation::Execute(
+StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
     const std::vector<Literal>& arguments) {
-  LocalClient* client = ClientLibrary::LocalClientOrDie();
+  LocalClient* client = GetOrCreateLocalClient();
 
-  // Transfer arguments in
-  std::vector<std::unique_ptr<ScopedShapedBuffer>> scoped_buffers;
-  scoped_buffers.reserve(arguments.size());
-  for (const Literal& argument : arguments) {
-    scoped_buffers.push_back(
-        client
-            ->LiteralToShapedBuffer(argument,
-                                    /*device_ordinal=*/0,
-                                    client->backend().memory_allocator())
-            .ConsumeValueOrDie());
+  // Each replica populates a StatusOr result, but only replica zero actually
+  // retrieves its literal value.
+  std::vector<StatusOr<std::unique_ptr<Literal>>> results(GetReplicaCount());
+  {
+    tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
+                                        GetReplicaCount());
+
+    for (int replica = 0; replica < GetReplicaCount(); ++replica) {
+      pool.Schedule([this, client, replica, &arguments, &results] {
+        StatusOr<int> device_ordinal_status =
+            client->ReplicaNumberToDeviceOrdinal(replica);
+        if (!device_ordinal_status.ok()) {
+          results[replica] = device_ordinal_status.status();
+          return;
+        }
+        const int device_ordinal = device_ordinal_status.ValueOrDie();
+        VLOG(3) << "Replica " << replica
+                << " mapped to device ordinal for execution: "
+                << device_ordinal;
+        // Transfer arguments in
+        std::vector<std::unique_ptr<ScopedShapedBuffer>> scoped_buffers;
+        scoped_buffers.reserve(arguments.size());
+        for (const Literal& argument : arguments) {
+          StatusOr<std::unique_ptr<ScopedShapedBuffer>> pushed =
+              client->LiteralToShapedBuffer(
+                  argument, device_ordinal,
+                  client->backend().memory_allocator());
+          if (!pushed.ok()) {
+            results[replica] = pushed.status();
+            return;
+          }
+          scoped_buffers.push_back(std::move(pushed).ValueOrDie());
+        }
+
+        // Execute
+        std::vector<const ShapedBuffer*> argument_buffers;
+        argument_buffers.reserve(scoped_buffers.size());
+        for (auto& buffer : scoped_buffers) {
+          argument_buffers.push_back(buffer.get());
+        }
+
+        DeviceAssignment device_assignment =
+            client->backend()
+                .computation_placer()
+                ->AssignDevices(GetReplicaCount(), /*computation_count=*/1)
+                .ConsumeValueOrDie();
+
+        ExecutableRunOptions options;
+        options.set_device_ordinal(device_ordinal);
+        options.set_allocator(client->backend().memory_allocator());
+        options.set_inter_op_thread_pool(
+            client->backend().inter_op_thread_pool());
+        options.set_intra_op_thread_pool(
+            client->backend().eigen_intra_op_thread_pool_device());
+        options.set_device_assignment(&device_assignment);
+        StatusOr<std::unique_ptr<ScopedShapedBuffer>> result_buffer_status =
+            executable_->Run(argument_buffers, options);
+        if (!result_buffer_status.ok()) {
+          results[replica] = result_buffer_status.status();
+          return;
+        }
+
+        // Transfer result out
+        results[replica] =
+            client->ShapedBufferToLiteral(*result_buffer_status.ValueOrDie());
+      });
+    }
   }
 
-  // Execute
-  std::vector<const ShapedBuffer*> argument_buffers;
-  argument_buffers.reserve(scoped_buffers.size());
-  for (auto& buffer : scoped_buffers) {
-    argument_buffers.push_back(buffer.get());
+  for (int replica = 0; replica < GetReplicaCount(); ++replica) {
+    const auto& statusor = results[replica];
+    if (!statusor.ok()) {
+      return InternalError(
+          "Failed running replica %d (other replicas may have failed as well): "
+          "%s.",
+          replica, statusor.status().ToString().c_str());
+    }
   }
-  ExecutableRunOptions options;
-  options.set_allocator(client->backend().memory_allocator());
-  options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool());
-  options.set_intra_op_thread_pool(
-      client->backend().eigen_intra_op_thread_pool_device());
-  std::unique_ptr<ScopedShapedBuffer> result_buffer =
-      executable_->Run(argument_buffers, options).ConsumeValueOrDie();
 
-  // Transfer result out
-  return client->ShapedBufferToLiteral(*result_buffer).ConsumeValueOrDie();
+  return std::move(results[0]);
 }
 
 LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers(
     tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles) {
-  LocalClient* client = ClientLibrary::LocalClientOrDie();
+  LocalClient* client = GetOrCreateLocalClient();
 
   std::vector<const ShapedBuffer*> argument_buffers;
   argument_buffers.reserve(argument_handles.size());
@@ -123,7 +223,7 @@
     argument_shape_pointers.push_back(&argument_shape);
   }
 
-  LocalClient* client = ClientLibrary::LocalClientOrDie();
+  LocalClient* client = GetOrCreateLocalClient();
   ExecutableBuildOptions options;
   TF_ASSIGN_OR_RETURN(
       auto local_executable,
@@ -136,7 +236,7 @@
 }
 
 LocalComputationBuilder::LocalComputationBuilder(const string& computation_name)
-    : builder_(ClientLibrary::LocalClientOrDie(), computation_name) {}
+    : builder_(GetOrCreateLocalClient(), computation_name) {}
 
 StatusOr<LocalComputation*> LocalComputationBuilder::Build() {
   TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build());
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 8da4c99..1104d7f 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -27,11 +27,25 @@
 
 namespace swig {
 
-// Wraps the local client's infeed-transfer function, aborting on error.
+// Initializes the number of replicas that XLA will be initialized with (when
+// first obtaining a handle to the local XLA service). If this is called after
+// the handle to the local XLA service has been established, then an error is
+// returned.
+Status InitializeReplicaCount(int replica_count);
+
+// Returns the replica count that is currently set, regardless of whether the
+// local XLA service has been instantiated yet or not.
+int GetReplicaCount();
+
+// Wraps the local client's infeed-transfer function.
 //
-// TODO(leary) ideally we could return a value that would permit an appropriate
-// Python exception to be raised.
-void TransferToInfeedLocal(const Literal& literal);
+// The default device ordinal (0) is used.
+Status TransferToInfeedLocal(const Literal& literal);
+
+// Transfers the given literal to the infeed of the given replica.
+//
+// The replica number is resolved to an appropriate device ordinal.
+Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number);
 
 // Wraps a ScopedShapedBuffer produced by copying a literal "to
 // device," i.e. copying a literal to a scoped buffer via the local
@@ -56,7 +70,8 @@
 class CompiledLocalComputation {
  public:
   CompiledLocalComputation(std::unique_ptr<LocalExecutable> executable);
-  std::unique_ptr<Literal> Execute(const std::vector<Literal>& arguments);
+  StatusOr<std::unique_ptr<Literal> > Execute(
+      const std::vector<Literal>& arguments);
   LocalShapedBuffer* ExecuteWithShapedBuffers(
       tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles);
 
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 40089b8..8b4779a 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -188,6 +188,15 @@
   }
 }
 
+%typemap(out) Status {  // Raises RuntimeError on non-OK status; else None.
+  if (!$1.ok()) {
+    PyErr_SetString(
+        PyExc_RuntimeError, $1.ToString().c_str());
+    return NULL;
+  }
+  $result = SWIG_Py_Void();  // New (INCREF'ed) reference to None.
+}
+
 // ArraySlice<int64>
 
 %typemap(in) tensorflow::gtl::ArraySlice<int64>
@@ -286,6 +295,14 @@
   $result = numpy::PyObjectFromXlaLiteral(*$1);
 }
 
+%typemap(out) StatusOr< std::unique_ptr<Literal> > {
+  if (!$1.ok()) {
+    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
+    return NULL;
+  }
+  $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
+}
+
 %typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
   if (!PySequence_Check($input)) {
     PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
@@ -540,7 +557,10 @@
 %ignoreall
 %unignore xla;
 %unignore xla::swig;
+%unignore xla::swig::InitializeReplicaCount;
+%unignore xla::swig::GetReplicaCount;
 %unignore xla::swig::TransferToInfeedLocal;
+%unignore xla::swig::TransferToInfeedLocalReplica;
 %unignore xla::swig::LocalShapedBuffer;
 %unignore xla::swig::LocalShapedBuffer::FromLiteral;
 %unignore xla::swig::LocalShapedBuffer::ToLiteral;
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 96e2401..fead7d6 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -209,20 +209,24 @@
     return np.require(value, requirements=['C', 'A'])
 
 
-def transfer_to_infeed(value):
+def transfer_to_infeed(value, replica_number=None):
   """Transfers the given value into the XLA infeed queue.
 
   XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with
   a totally ordered stream of values. This is dequeued from XLA computations via
   the Infeed() operation.
 
-  TODO(leary): this currently implicitly enqueues to device ordinal 0.
-
   Args:
     value: the value that the caller would like to enqueue into the XLA infeed
       queue
+    replica_number: the replica number to infeed the value to -- if not
+      provided, then the default replica (trivially replica 0) is used.
   """
-  c_api.TransferToInfeedLocal(require_numpy_array_layout(value))
+  if replica_number is None:
+    c_api.TransferToInfeedLocal(require_numpy_array_layout(value))
+  else:
+    c_api.TransferToInfeedLocalReplica(
+        require_numpy_array_layout(value), replica_number)
 
 
 class LocalComputation(object):
@@ -832,3 +836,25 @@
 
 
 _forward_methods_to_local_builder()
+
+
+def initialize_replica_count(replica_count):
+  """Initializes the desired replica count to use on XLA service init.
+
+  Args:
+    replica_count: number of replicas that are desired for set up during XLA
+      initialization.
+
+  Raises:
+    A runtime exception if the XLA service has already been initialized.
+  """
+  c_api.InitializeReplicaCount(replica_count)
+
+
+def get_replica_count():
+  """Returns the current replica count used for the XLA service.
+
+  Note: this will return a value whether the XLA service has been initialized
+  yet or not.
+  """
+  return c_api.GetReplicaCount()
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 173f0e2..ddd75bb 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -467,6 +467,27 @@
     return HandleSign<ReturnT>(sign);
   }
 
+  template <typename NativeT, typename std::enable_if<std::is_floating_point<
+                                  NativeT>::value>::type* = nullptr>
+  Status HandleAtan2(HloInstruction* atan2) {
+    TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2],
+                        ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem,
+                                                      ElementwiseT rhs_elem) {
+                          return std::atan2(lhs_elem, rhs_elem);
+                        }));
+    return Status::OK();
+  }
+
+  template <typename NativeT, typename std::enable_if<!std::is_floating_point<
+                                  NativeT>::value>::type* = nullptr>
+  Status HandleAtan2(HloInstruction* atan2) {
+    return InvalidArgument("Unsupported type for Atan2");
+  }
+
+  Status HandleAtan2(HloInstruction* atan2) override {
+    return HandleAtan2<ElementwiseT>(atan2);
+  }
+
   Status HandleTanh(HloInstruction* tanh) override {
     TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh],
                         ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index c5cab92..ff3dbdd 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -25,6 +25,7 @@
 #include <iosfwd>
 #include <list>
 #include <memory>
+#include <set>
 #include <string>
 #include <tuple>
 #include <unordered_map>
@@ -1451,6 +1452,10 @@
 using ConstHloInstructionMap =
     std::map<const HloInstruction*, ValueT, HloPtrComparator>;
 
+using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
+using ConstHloInstructionSet =
+    std::set<const HloInstruction*, HloPtrComparator>;
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 7b3a8ce..204a8bf 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -169,8 +169,16 @@
       std::unique_ptr<ScopedShapedBuffer> scoped_result,
       ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator()));
 
-  return backend().transfer_manager()->TransferLiteralFromDevice(
+  auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice(
       stream.parent(), *scoped_result);
+  if (result_literal.ok()) {
+    VLOG(4) << "Executed binary and got result: "
+            << result_literal.ValueOrDie()->ToString();
+  } else {
+    VLOG(4) << "Executed binary and got status: "
+            << result_literal.status().ToString();
+  }
+  return result_literal;
 }
 
 Backend& HloRunner::backend() {
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 4071b94..d5715aa 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -122,4 +122,10 @@
                          execute_backend_.get(), executor);
 }
 
+StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
+  return backend().computation_placer()->DeviceId(
+      replica_number, /*computation=*/0, options_.number_of_replicas(),
+      /*computation_count=*/1);
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
index 52c4346..acbc726 100644
--- a/tensorflow/compiler/xla/service/local_service.h
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -47,6 +47,13 @@
       const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
       const Shape* result_layout, int device_ordinal);
 
+  // Returns the device ordinal that corresponds to the given replica number.
+  //
+  // This returns an error if there is not a one-to-one correspondence of
+  // replicas to device ordinals, but is useful as a short term mechanism for
+  // the "easy" case where a single replica is a single device.
+  StatusOr<int> ReplicaNumberToDeviceOrdinal(int replica_number);
+
  private:
   explicit LocalService(const ServiceOptions& options,
                         std::unique_ptr<Backend> backend);
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index 5ca6ccb..c3743b1 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -199,12 +199,10 @@
   StatusOr<const LogicalBuffer*> GetBufferDefinedAt(
       const HloInstruction* instruction, const ShapeIndex& index) const;
 
-  // Return a vector containing all BufferAliases of the given logical buffer
-  // This trivially includes the BufferAlias with same instruction and index as
-  // the logical buffer itself, so the returned vector is never empty.  The
-  // buffer alias set is the inverse of the points-to set. That is,
-  // LogicalBuffer B is in the points-to set of instruction I at index N iff
-  // instruction I, index N is a BufferAlias of B.
+  // Return a (possibly empty) vector containing all BufferAliases of the given
+  // logical buffer. The buffer alias set is the inverse of the points-to set.
+  // That is, LogicalBuffer B is in the points-to set of instruction I at index
+  // N iff instruction I, index N is a BufferAlias of B.
   using BufferAliasVector = tensorflow::gtl::InlinedVector<BufferAlias, 1>;
   const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const;
 
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index fb0e6f7..1073cc7 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -306,6 +306,11 @@
     return false;
   }
 
+  if (while_body_root->opcode() != HloOpcode::kTuple) {
+    VLOG(2) << "While body's root is not a tuple(...) instruction.";
+    return false;
+  }
+
   auto print_no_metadata = HloPrintOptions().set_print_metadata(false);
 
   // Bail if param0 of while_cond or while_body has users which aren't of type
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index d99b31d..c5183f8 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -418,5 +418,32 @@
                      op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1)));
 }
 
+TEST_F(WhileLoopSimplifierTest, BodyHasNonTupleRoot) {
+  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+  Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
+
+  HloComputation* while_body = [&]() {
+    HloComputation::Builder builder(TestName() + ".passthrough");
+    HloInstruction* param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, while_shape, "param"));
+    HloComputation* result = module().AddEmbeddedComputation(builder.Build());
+
+    result->AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+    return result;
+  }();
+
+  HloComputation::Builder builder(TestName());
+  auto* init_value = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, while_shape, "init_value"));
+  builder.AddInstruction(HloInstruction::CreateWhile(
+      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
+      while_body, init_value));
+  module().AddEntryComputation(builder.Build());
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopSimplifier{}.Run(&module()));
+  EXPECT_FALSE(simplified_loop);
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index c6e8b24..935b94c 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -1971,6 +1971,18 @@
                              error_spec_);
 }
 
+XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) {
+  ComputationBuilder builder(client_, TestName());
+  auto a = builder.ConstantR1<float>({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f});
+  auto b = builder.ConstantR1<float>({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f});
+  auto atan = builder.Atan2(a, b);
+
+  ComputeAndCompareR1<float>(
+      &builder,
+      {0.0f, 1.57079633f, 3.14159265f, -1.57079633f, 0.78539816f, -0.78539816f},
+      {}, error_spec_);
+}
+
 XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) {
   ComputationBuilder builder(client_, TestName());
   auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f});
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index e8a05cf..bb215be 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -29,7 +29,7 @@
   CHECK_EQ(literal->shape().element_type(),
            primitive_util::NativeToPrimitiveType<FloatT>());
   std::minstd_rand0 engine;
-  std::uniform_real_distribution<FloatT> generator(0.0f, 1.0f);
+  std::uniform_real_distribution<FloatT> generator(-0.9f, 1.0f);
   TF_CHECK_OK(literal->Populate<FloatT>(
       [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
         return generator(engine);
@@ -42,7 +42,7 @@
 void PopulateWithRandomFloatingPointData<bfloat16>(Literal* literal) {
   CHECK_EQ(literal->shape().element_type(), BF16);
   std::minstd_rand0 engine;
-  std::uniform_real_distribution<float> generator(0.0f, 1.0f);
+  std::uniform_real_distribution<float> generator(-0.9f, 1.0f);
   TF_CHECK_OK(literal->Populate<bfloat16>(
       [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
         return static_cast<bfloat16>(generator(engine));
@@ -126,6 +126,11 @@
                                 fused_uses.end());
       } else if (NeedsZeroInitValue(use)) {
         constrained_uses.push_back(instruction);
+      } else if (opcode == HloOpcode::kConvert ||
+                 opcode == HloOpcode::kReducePrecision) {
+        auto converted_uses = FindConstrainedUses(dataflow, *instruction);
+        constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
+                                converted_uses.end());
       }
     }
   }
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index 742835c..c9a3537 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -83,8 +83,11 @@
       return Status::OK();
     }
     AttrValueMap attr_values = func_.attr();
-    return lib_->Instantiate(func_.name(), AttrSlice(&attr_values),
-                             {target_device_}, &handle_);
+    AttrValue v;
+    v.set_s(target_device_);
+    AddAttr("_target", v, &attr_values);
+
+    return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), &handle_);
   }
 
   // Returns true if we've got to the end of the sequence and exhausted the
diff --git a/tensorflow/contrib/distributions/python/ops/test_util.py b/tensorflow/contrib/distributions/python/ops/test_util.py
index bfc7274..15b0820 100644
--- a/tensorflow/contrib/distributions/python/ops/test_util.py
+++ b/tensorflow/contrib/distributions/python/ops/test_util.py
@@ -327,7 +327,7 @@
       num_samples=int(1e5),
       seed=24,
       rtol=1e-2,
-      atol=0.,
+      atol=0.1,
       cov_rtol=None,
       cov_atol=None):
     """Tests that sample/mean/covariance are consistent with each other.
diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index 6aef010..15a2188 100644
--- a/tensorflow/contrib/eager/python/examples/BUILD
+++ b/tensorflow/contrib/eager/python/examples/BUILD
@@ -6,6 +6,7 @@
 py_library(
     name = "examples_pip",
     deps = [
+        "//tensorflow/contrib/eager/python/examples/gan:mnist",
         "//tensorflow/contrib/eager/python/examples/linear_regression",
         "//tensorflow/contrib/eager/python/examples/mnist",
         "//tensorflow/contrib/eager/python/examples/resnet50",
diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD
new file mode 100644
index 0000000..c61ec2d
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/BUILD
@@ -0,0 +1,36 @@
+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+py_binary(
+    name = "mnist",
+    srcs = ["mnist.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/contrib/eager/python:tfe",
+        "//tensorflow/examples/tutorials/mnist:input_data",
+    ],
+)
+
+cuda_py_test(
+    name = "mnist_test",
+    srcs = ["mnist_test.py"],
+    additional_deps = [
+        ":mnist",
+        "//tensorflow/contrib/eager/python:tfe",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+cuda_py_test(
+    name = "mnist_graph_test",
+    srcs = ["mnist_graph_test.py"],
+    additional_deps = [
+        ":mnist",
+        "//third_party/py/numpy",
+        "//tensorflow:tensorflow_py",
+    ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/gan/README.md b/tensorflow/contrib/eager/python/examples/gan/README.md
new file mode 100644
index 0000000..e8c9db1
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/README.md
@@ -0,0 +1,38 @@
+# GAN with TensorFlow eager execution
+
+A simple Generative Adversarial Network (GAN) example using eager execution.
+The discriminator and generator networks each contain a few convolution and
+fully connected layers.
+
+Other eager execution examples can be found under the parent directory.
+
+##  Content
+
+- `mnist.py`: Model definitions and training routines.
+- `mnist_test.py`: Benchmarks for training and using the models using eager
+execution.
+- `mnist_graph_test.py`: Benchmarks for training and using the models using
+graph execution. The same model definitions and loss functions are used in
+all benchmarks.
+
+
+## To run
+
+- Make sure you have installed TensorFlow 1.5+ or the latest `tf-nightly`
+or `tf-nightly-gpu` pip package in order to access the eager execution feature.
+
+- Train model. E.g.,
+
+  ```bash
+  python mnist.py
+  ```
+  
+  Use `--output_dir=<DIR>` to direct the script to save TensorBoard summaries
+  during training. Disabled by default.
+  
+  Use `--checkpoint_dir=<DIR>` to direct the script to save checkpoints to
+  `<DIR>` during training. DIR defaults to /tmp/tensorflow/mnist/checkpoints/.
+  The script will load the latest saved checkpoint from this directory if
+  one exists.
+  
+  Use `-h` for other options.
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py
new file mode 100644
index 0000000..b9ac79f
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py
@@ -0,0 +1,368 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A simple GAN for generating MNIST digits, using eager execution.
+
+Sample usage:
+  python mnist.py --help
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+import time
+
+import tensorflow as tf
+
+import tensorflow.contrib.eager as tfe
+from tensorflow.examples.tutorials.mnist import input_data
+
+FLAGS = None
+
+
+class Discriminator(tfe.Network):
+  """GAN Discriminator.
+
+  A network to differentiate between generated and real handwritten digits.
+  """
+
+  def __init__(self, data_format):
+    """Creates a model for discriminating between real and generated digits.
+
+    Args:
+      data_format: Either 'channels_first' or 'channels_last'.
+        'channels_first' is typically faster on GPUs while 'channels_last' is
+        typically faster on CPUs. See
+        https://www.tensorflow.org/performance/performance_guide#data_formats
+    """
+    super(Discriminator, self).__init__(name='')
+    if data_format == 'channels_first':
+      self._input_shape = [-1, 1, 28, 28]
+    else:
+      assert data_format == 'channels_last'
+      self._input_shape = [-1, 28, 28, 1]
+    self.conv1 = self.track_layer(tf.layers.Conv2D(64, 5, padding='SAME',
+                                                   data_format=data_format,
+                                                   activation=tf.tanh))
+    self.pool1 = self.track_layer(
+        tf.layers.AveragePooling2D(2, 2, data_format=data_format))
+    self.conv2 = self.track_layer(tf.layers.Conv2D(128, 5,
+                                                   data_format=data_format,
+                                                   activation=tf.tanh))
+    self.pool2 = self.track_layer(
+        tf.layers.AveragePooling2D(2, 2, data_format=data_format))
+    self.flatten = self.track_layer(tf.layers.Flatten())
+    self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.tanh))
+    self.fc2 = self.track_layer(tf.layers.Dense(1, activation=None))
+
+  def call(self, inputs):
+    """Return one logit per image estimating input authenticity.
+
+    Users should invoke __call__ to run the network, which delegates to this
+    method (and not call this method directly).
+
+    Args:
+      inputs: A batch of images as a Tensor with shape [batch_size, 28, 28, 1]
+        or [batch_size, 1, 28, 28]
+
+    Returns:
+      A Tensor with shape [batch_size, 1] containing logits estimating
+      the probability that the corresponding digit is real.
+    """
+    x = tf.reshape(inputs, self._input_shape)
+    x = self.conv1(x)
+    x = self.pool1(x)
+    x = self.conv2(x)
+    x = self.pool2(x)
+    x = self.flatten(x)
+    x = self.fc1(x)
+    x = self.fc2(x)
+    return x
+
+
+class Generator(tfe.Network):
+  """Generator of handwritten digits similar to the ones in the MNIST dataset.
+  """
+
+  def __init__(self, data_format):
+    """Creates a model for generating digit images from noise vectors.
+
+    Args:
+      data_format: Either 'channels_first' or 'channels_last'.
+        'channels_first' is typically faster on GPUs while 'channels_last' is
+        typically faster on CPUs. See
+        https://www.tensorflow.org/performance/performance_guide#data_formats
+    """
+    super(Generator, self).__init__(name='')
+    self.data_format = data_format
+    # We are using 128 6x6 channels as input to the first deconvolution layer
+    if data_format == 'channels_first':
+      self._pre_conv_shape = [-1, 128, 6, 6]
+    else:
+      assert data_format == 'channels_last'
+      self._pre_conv_shape = [-1, 6, 6, 128]
+    self.fc1 = self.track_layer(tf.layers.Dense(6 * 6 * 128,
+                                                activation=tf.tanh))
+
+    # In call(), we reshape the output of fc1 to _pre_conv_shape
+
+    # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64)
+    self.conv1 = self.track_layer(tf.layers.Conv2DTranspose(
+        64, 4, strides=2, activation=None, data_format=data_format))
+
+    # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1)
+    self.conv2 = self.track_layer(tf.layers.Conv2DTranspose(
+        1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format))
+
+  def call(self, inputs):
+    """Return a batch of generated images.
+
+    Users should invoke __call__ to run the network, which delegates to this
+    method (and not call this method directly).
+
+    Args:
+      inputs: A batch of noise vectors as a Tensor with shape
+        [batch_size, length of noise vectors].
+
+    Returns:
+      A Tensor containing generated images. If data_format is 'channels_last',
+      the shape of returned images is [batch_size, 28, 28, 1], else
+      [batch_size, 1, 28, 28]
+    """
+
+    x = self.fc1(inputs)
+    x = tf.reshape(x, shape=self._pre_conv_shape)
+    x = self.conv1(x)
+    x = self.conv2(x)
+    return x
+
+
+def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs):
+  """Original discriminator loss for GANs, with label smoothing.
+
+  See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
+  details.
+
+  Args:
+    discriminator_real_outputs: Discriminator output on real data.
+    discriminator_gen_outputs: Discriminator output on generated data. Expected
+      to be in the range of (-inf, inf).
+
+  Returns:
+    A scalar loss Tensor.
+  """
+
+  loss_on_real = tf.losses.sigmoid_cross_entropy(
+      tf.ones_like(discriminator_real_outputs), discriminator_real_outputs,
+      label_smoothing=0.25)
+  loss_on_generated = tf.losses.sigmoid_cross_entropy(
+      tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs)
+  loss = loss_on_real + loss_on_generated
+  tf.contrib.summary.scalar('discriminator_loss', loss)
+  return loss
+
+
+def generator_loss(discriminator_gen_outputs):
+  """Original generator loss for GANs.
+
+  L = -log(sigmoid(D(G(z))))
+
+  See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661)
+  for more details.
+
+  Args:
+    discriminator_gen_outputs: Discriminator output on generated data. Expected
+      to be in the range of (-inf, inf).
+
+  Returns:
+    A scalar loss Tensor.
+  """
+  loss = tf.losses.sigmoid_cross_entropy(
+      tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs)
+  tf.contrib.summary.scalar('generator_loss', loss)
+  return loss
+
+
+def train_one_epoch(generator, discriminator,
+                    generator_optimizer, discriminator_optimizer,
+                    dataset, log_interval, noise_dim):
+  """Trains `generator` and `discriminator` models on `dataset`.
+
+  Args:
+    generator: Generator model.
+    discriminator: Discriminator model.
+    generator_optimizer: Optimizer to use for generator.
+    discriminator_optimizer: Optimizer to use for discriminator.
+    dataset: Dataset of images to train on.
+    log_interval: How many global steps to wait between logging and collecting
+      summaries.
+    noise_dim: Dimension of noise vector to use.
+  """
+
+  total_generator_loss = 0.0
+  total_discriminator_loss = 0.0
+  for (batch_index, images) in enumerate(tfe.Iterator(dataset)):
+    with tf.device('/cpu:0'):
+      tf.assign_add(tf.train.get_global_step(), 1)
+
+    with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval):
+      current_batch_size = images.shape[0]
+      noise = tf.random_uniform(shape=[current_batch_size, noise_dim],
+                                minval=-1., maxval=1., seed=batch_index)
+
+      with tfe.GradientTape(persistent=True) as g:
+        generated_images = generator(noise)
+        tf.contrib.summary.image('generated_images',
+                                 tf.reshape(generated_images, [-1, 28, 28, 1]),
+                                 max_images=10)
+
+        discriminator_gen_outputs = discriminator(generated_images)
+        discriminator_real_outputs = discriminator(images)
+        discriminator_loss_val = discriminator_loss(discriminator_real_outputs,
+                                                    discriminator_gen_outputs)
+        total_discriminator_loss += discriminator_loss_val
+
+        generator_loss_val = generator_loss(discriminator_gen_outputs)
+        total_generator_loss += generator_loss_val
+
+      generator_grad = g.gradient(generator_loss_val, generator.variables)
+      discriminator_grad = g.gradient(discriminator_loss_val,
+                                      discriminator.variables)
+
+      with tf.variable_scope('generator'):
+        generator_optimizer.apply_gradients(zip(generator_grad,
+                                                generator.variables))
+      with tf.variable_scope('discriminator'):
+        discriminator_optimizer.apply_gradients(zip(discriminator_grad,
+                                                    discriminator.variables))
+
+      if log_interval and batch_index > 0 and batch_index % log_interval == 0:
+        print('Batch #%d\tAverage Generator Loss: %.6f\t'
+              'Average Discriminator Loss: %.6f' % (
+                  batch_index, total_generator_loss/batch_index,
+                  total_discriminator_loss/batch_index))
+
+
+def main(_):
+  (device, data_format) = ('/gpu:0', 'channels_first')
+  if FLAGS.no_gpu or tfe.num_gpus() <= 0:
+    (device, data_format) = ('/cpu:0', 'channels_last')
+  print('Using device %s, and data format %s.' % (device, data_format))
+
+  # Load the datasets
+  data = input_data.read_data_sets(FLAGS.data_dir)
+  dataset = (tf.data.Dataset
+             .from_tensor_slices(data.train.images)
+             .shuffle(60000)
+             .batch(FLAGS.batch_size))
+
+  # Create the models and optimizers
+  generator = Generator(data_format)
+  discriminator = Discriminator(data_format)
+  with tf.variable_scope('generator'):
+    generator_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
+  with tf.variable_scope('discriminator'):
+    discriminator_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
+
+  # Prepare summary writer and checkpoint info
+  summary_writer = tf.contrib.summary.create_summary_file_writer(
+      FLAGS.output_dir, flush_millis=1000)
+  checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
+  latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
+  if latest_cpkt:
+    print('Using latest checkpoint at ' + latest_cpkt)
+
+  with tf.device(device):
+    for epoch in range(1, 101):
+      with tfe.restore_variables_on_create(latest_cpkt):
+        global_step = tf.train.get_or_create_global_step()
+        start = time.time()
+        with summary_writer.as_default():
+          train_one_epoch(generator, discriminator, generator_optimizer,
+                          discriminator_optimizer,
+                          dataset, FLAGS.log_interval, FLAGS.noise)
+        end = time.time()
+        print('\nTrain time for epoch #%d (global step %d): %f' % (
+            epoch, global_step.numpy(), end - start))
+
+      all_variables = (
+          generator.variables
+          + discriminator.variables
+          + generator_optimizer.variables()
+          + discriminator_optimizer.variables()
+          + [global_step])
+      tfe.Saver(all_variables).save(
+          checkpoint_prefix, global_step=global_step)
+
+
+if __name__ == '__main__':
+  tfe.enable_eager_execution()
+
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--data-dir',
+      type=str,
+      default='/tmp/tensorflow/mnist/input_data',
+      help=('Directory for storing input data (default '
+            '/tmp/tensorflow/mnist/input_data)'))
+  parser.add_argument(
+      '--batch-size',
+      type=int,
+      default=128,
+      metavar='N',
+      help='input batch size for training (default: 128)')
+  parser.add_argument(
+      '--log-interval',
+      type=int,
+      default=100,
+      metavar='N',
+      help=('number of batches between logging and writing summaries '
+            '(default: 100)'))
+  parser.add_argument(
+      '--output_dir',
+      type=str,
+      default=None,
+      metavar='DIR',
+      help='Directory to write TensorBoard summaries (defaults to none)')
+  parser.add_argument(
+      '--checkpoint_dir',
+      type=str,
+      default='/tmp/tensorflow/mnist/checkpoints/',
+      metavar='DIR',
+      help=('Directory to save checkpoints in (once per epoch) (default '
+            '/tmp/tensorflow/mnist/checkpoints/)'))
+  parser.add_argument(
+      '--lr',
+      type=float,
+      default=0.001,
+      metavar='LR',
+      help='learning rate (default: 0.001)')
+  parser.add_argument(
+      '--noise',
+      type=int,
+      default=100,
+      metavar='N',
+      help='Length of noise vector for generator input (default: 100)')
+  parser.add_argument(
+      '--no-gpu',
+      action='store_true',
+      default=False,
+      help='disables GPU usage even if a GPU is available')
+
+  FLAGS, unparsed = parser.parse_known_args()
+  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py
new file mode 100644
index 0000000..12b39b0
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py
@@ -0,0 +1,151 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tempfile
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.eager.python.examples.gan import mnist
+
+NOISE_DIM = 100
+# Big enough so that summaries are never recorded.
+# Lower this value if you would like to benchmark with some summaries.
+SUMMARY_INTERVAL = 10000
+SUMMARY_FLUSH_MS = 100  # Flush summaries every 100ms
+
+
+def data_format():
+  return 'channels_first' if tf.test.is_gpu_available() else 'channels_last'
+
+
+class MnistGraphGanBenchmark(tf.test.Benchmark):
+
+  def _create_graph(self, batch_size):
+    # Generate some random data.
+    images_data = np.random.randn(batch_size, 784).astype(np.float32)
+    dataset = tf.data.Dataset.from_tensors(images_data)
+    images = dataset.repeat().make_one_shot_iterator().get_next()
+
+    # Create the models and optimizers
+    generator = mnist.Generator(data_format())
+    discriminator = mnist.Discriminator(data_format())
+    with tf.variable_scope('generator'):
+      generator_optimizer = tf.train.AdamOptimizer(0.001)
+    with tf.variable_scope('discriminator'):
+      discriminator_optimizer = tf.train.AdamOptimizer(0.001)
+
+    # Run models and compute loss
+    noise_placeholder = tf.placeholder(tf.float32,
+                                       shape=[batch_size, NOISE_DIM])
+    generated_images = generator(noise_placeholder)
+    tf.contrib.summary.image('generated_images',
+                             tf.reshape(generated_images, [-1, 28, 28, 1]),
+                             max_images=10)
+    discriminator_gen_outputs = discriminator(generated_images)
+    discriminator_real_outputs = discriminator(images)
+    generator_loss = mnist.generator_loss(discriminator_gen_outputs)
+    discriminator_loss = mnist.discriminator_loss(discriminator_real_outputs,
+                                                  discriminator_gen_outputs)
+    # Get train ops
+    with tf.variable_scope('generator'):
+      generator_train = generator_optimizer.minimize(
+          generator_loss, var_list=generator.variables)
+    with tf.variable_scope('discriminator'):
+      discriminator_train = discriminator_optimizer.minimize(
+          discriminator_loss, var_list=discriminator.variables)
+
+    return (generator_train, discriminator_train, noise_placeholder)
+
+  def _report(self, test_name, start, num_iters, batch_size):
+    avg_time = (time.time() - start) / num_iters
+    dev = 'gpu' if tf.test.is_gpu_available() else 'cpu'
+    name = 'graph_%s_%s_batch_%d_%s' % (test_name, dev, batch_size,
+                                        data_format())
+    extras = {'examples_per_sec': batch_size / avg_time}
+    self.report_benchmark(
+        iters=num_iters, wall_time=avg_time, name=name, extras=extras)
+
+  def benchmark_train(self):
+    for batch_size in [64, 128, 256]:
+      with tf.Graph().as_default():
+        global_step = tf.train.get_or_create_global_step()
+        increment_global_step = tf.assign_add(global_step, 1)
+        with tf.contrib.summary.create_file_writer(
+            tempfile.mkdtemp(), flush_millis=SUMMARY_FLUSH_MS).as_default(), (
+                tf.contrib.summary.record_summaries_every_n_global_steps(
+                    SUMMARY_INTERVAL)):
+          (generator_train, discriminator_train, noise_placeholder
+          ) = self._create_graph(batch_size)
+
+          with tf.Session() as sess:
+            tf.contrib.summary.initialize(graph=tf.get_default_graph(),
+                                          session=sess)
+
+            sess.run(tf.global_variables_initializer())
+
+            num_burn, num_iters = (3, 100)
+            for _ in range(num_burn):
+              noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM])
+              # Increment global step before evaluating summary ops to avoid
+              # race condition.
+              sess.run(increment_global_step)
+              sess.run([generator_train, discriminator_train,
+                        tf.contrib.summary.all_summary_ops()],
+                       feed_dict={noise_placeholder: noise})
+
+            # Run and benchmark num_iters training steps.
+            start = time.time()
+            for _ in range(num_iters):
+              noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM])
+              sess.run(increment_global_step)
+              sess.run([generator_train, discriminator_train,
+                        tf.contrib.summary.all_summary_ops()],
+                       feed_dict={noise_placeholder: noise})
+            self._report('train', start, num_iters, batch_size)
+
+  def benchmark_generate(self):
+    for batch_size in [64, 128, 256]:
+      with tf.Graph().as_default():
+        # Using random weights. This will generate garbage.
+        generator = mnist.Generator(data_format())
+        noise_placeholder = tf.placeholder(tf.float32,
+                                           shape=[batch_size, NOISE_DIM])
+        generated_images = generator(noise_placeholder)
+
+        init = tf.global_variables_initializer()
+        with tf.Session() as sess:
+          sess.run(init)
+          noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM])
+          num_burn, num_iters = (30, 1000)
+          for _ in range(num_burn):
+            sess.run(generated_images, feed_dict={noise_placeholder: noise})
+
+          start = time.time()
+          for _ in range(num_iters):
+            # Comparison with the eager execution benchmark in mnist_test.py
+            # isn't entirely fair as the time here includes the cost of copying
+            # the feeds from CPU memory to GPU.
+            sess.run(generated_images, feed_dict={noise_placeholder: noise})
+          self._report('generate', start, num_iters, batch_size)
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py
new file mode 100644
index 0000000..4a3ca8d
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py
@@ -0,0 +1,113 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tempfile
+import time
+
+import tensorflow as tf
+
+import tensorflow.contrib.eager as tfe
+from tensorflow.contrib.eager.python.examples.gan import mnist
+
+NOISE_DIM = 100
+# Big enough so that summaries are never recorded.
+# Lower this value if you would like to benchmark with some summaries.
+SUMMARY_INTERVAL = 10000
+SUMMARY_FLUSH_MS = 100  # Flush summaries every 100ms
+
+
+def data_format():
+  return 'channels_first' if tf.test.is_gpu_available() else 'channels_last'
+
+
+def device():
+  return '/gpu:0' if tfe.num_gpus() else '/cpu:0'
+
+
+class MnistEagerGanBenchmark(tf.test.Benchmark):
+
+  def _report(self, test_name, start, num_iters, batch_size):
+    avg_time = (time.time() - start) / num_iters
+    dev = 'gpu' if tfe.num_gpus() else 'cpu'
+    name = 'eager_%s_%s_batch_%d_%s' % (test_name, dev, batch_size,
+                                        data_format())
+    extras = {'examples_per_sec': batch_size / avg_time}
+    self.report_benchmark(
+        iters=num_iters, wall_time=avg_time, name=name, extras=extras)
+
+  def benchmark_train(self):
+    for batch_size in [64, 128, 256]:
+      # Generate some random data.
+      burn_batches, measure_batches = (3, 100)
+      burn_images = [tf.random_normal([batch_size, 784])
+                     for _ in range(burn_batches)]
+      burn_dataset = tf.data.Dataset.from_tensor_slices(burn_images)
+      measure_images = [tf.random_normal([batch_size, 784])
+                        for _ in range(measure_batches)]
+      measure_dataset = tf.data.Dataset.from_tensor_slices(measure_images)
+
+      tf.train.get_or_create_global_step()
+      with tf.device(device()):
+        # Create the models and optimizers
+        generator = mnist.Generator(data_format())
+        discriminator = mnist.Discriminator(data_format())
+        with tf.variable_scope('generator'):
+          generator_optimizer = tf.train.AdamOptimizer(0.001)
+        with tf.variable_scope('discriminator'):
+          discriminator_optimizer = tf.train.AdamOptimizer(0.001)
+
+        with tf.contrib.summary.create_file_writer(
+            tempfile.mkdtemp(), flush_millis=SUMMARY_FLUSH_MS).as_default():
+
+          # warm up
+          mnist.train_one_epoch(generator, discriminator, generator_optimizer,
+                                discriminator_optimizer,
+                                burn_dataset, log_interval=SUMMARY_INTERVAL,
+                                noise_dim=NOISE_DIM)
+          # measure
+          start = time.time()
+          mnist.train_one_epoch(generator, discriminator, generator_optimizer,
+                                discriminator_optimizer,
+                                measure_dataset, log_interval=SUMMARY_INTERVAL,
+                                noise_dim=NOISE_DIM)
+          self._report('train', start, measure_batches, batch_size)
+
+  def benchmark_generate(self):
+    for batch_size in [64, 128, 256]:
+      with tf.device(device()):
+        # Using random weights. This will generate garbage.
+        generator = mnist.Generator(data_format())
+
+        num_burn, num_iters = (30, 1000)
+        for _ in range(num_burn):
+          noise = tf.random_uniform(shape=[batch_size, NOISE_DIM],
+                                    minval=-1., maxval=1.)
+          generator(noise)
+
+        start = time.time()
+        for _ in range(num_iters):
+          noise = tf.random_uniform(shape=[batch_size, NOISE_DIM],
+                                    minval=-1., maxval=1.)
+          generator(noise)
+        self._report('generate', start, num_iters, batch_size)
+
+
+if __name__ == '__main__':
+  tfe.enable_eager_execution()
+  tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py
index 205709f..136085e 100644
--- a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py
+++ b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py
@@ -39,22 +39,40 @@
   return tf.data.Dataset.from_tensors((images, labels))
 
 
+def train_one_epoch(defun=False):
+  model = mnist.MNISTModel(data_format())
+  if defun:
+    model.call = tfe.defun(model.call)
+  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
+  dataset = random_dataset()
+  with tf.device(device()):
+    tf.train.get_or_create_global_step()
+    mnist.train_one_epoch(model, optimizer, dataset)
+
+
+def evaluate(defun=False):
+  model = mnist.MNISTModel(data_format())
+  dataset = random_dataset()
+  if defun:
+    model.call = tfe.defun(model.call)
+  with tf.device(device()):
+    tf.train.get_or_create_global_step()
+    mnist.test(model, dataset)
+
+
 class MNISTTest(tf.test.TestCase):
 
   def testTrainOneEpoch(self):
-    model = mnist.MNISTModel(data_format())
-    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
-    dataset = random_dataset()
-    with tf.device(device()):
-      tf.train.get_or_create_global_step()
-      mnist.train_one_epoch(model, optimizer, dataset)
+    train_one_epoch(defun=False)
 
   def testTest(self):
-    model = mnist.MNISTModel(data_format())
-    dataset = random_dataset()
-    with tf.device(device()):
-      tf.train.get_or_create_global_step()
-      mnist.test(model, dataset)
+    evaluate(defun=False)
+
+  def testTrainOneEpochWithDefunCall(self):
+    train_one_epoch(defun=True)
+
+  def testTestWithDefunCall(self):
+    evaluate(defun=True)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index d8d8644..932f95c 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -64,14 +64,22 @@
 
 class ResNet50Test(tf.test.TestCase):
 
-  def test_apply(self):
+  def _apply(self, defun=False):
     device, data_format = device_and_data_format()
     model = resnet50.ResNet50(data_format)
+    if defun:
+      model.call = tfe.defun(model.call)
     with tf.device(device):
       images, _ = random_batch(2)
       output = model(images)
     self.assertEqual((2, 1000), output.shape)
 
+  def test_apply(self):
+    self._apply(defun=False)
+
+  def test_apply_with_defun(self):
+    self._apply(defun=True)
+
   def test_apply_no_top(self):
     device, data_format = device_and_data_format()
     model = resnet50.ResNet50(data_format, include_top=False)
@@ -175,9 +183,11 @@
     # a sync. This is a roundabout way, yes.
     tf.constant(1.).cpu()
 
-  def benchmark_eager_apply(self):
+  def _benchmark_eager_apply(self, label, defun=False):
     device, data_format = device_and_data_format()
     model = resnet50.ResNet50(data_format)
+    if defun:
+      model.call = tfe.defun(model.call)
     batch_size = 64
     num_burn = 5
     num_iters = 30
@@ -189,16 +199,23 @@
       start = time.time()
       for _ in xrange(num_iters):
         model(images).cpu()
-      self._report('eager_apply', start, num_iters, device, batch_size,
-                   data_format)
+      self._report(label, start, num_iters, device, batch_size, data_format)
 
-  def _benchmark_eager_train(self, label, make_iterator):
+  def benchmark_eager_apply(self):
+    self._benchmark_eager_apply('eager_apply', defun=False)
+
+  def benchmark_eager_apply_with_defun(self):
+    self._benchmark_eager_apply('eager_apply_with_defun', defun=True)
+
+  def _benchmark_eager_train(self, label, make_iterator, defun=False):
     device, data_format = device_and_data_format()
     for batch_size in self._train_batch_sizes():
       (images, labels) = random_batch(batch_size)
       num_burn = 3
       num_iters = 10
       model = resnet50.ResNet50(data_format)
+      if defun:
+        model.call = tfe.defun(model.call)
       optimizer = tf.train.GradientDescentOptimizer(0.1)
 
       with tf.device(device):
@@ -217,7 +234,11 @@
         self._report(label, start, num_iters, device, batch_size, data_format)
 
   def benchmark_eager_train(self):
-    self._benchmark_eager_train('eager_train', MockIterator)
+    self._benchmark_eager_train('eager_train', MockIterator, defun=False)
+
+  def benchmark_eager_train_with_defun(self):
+    self._benchmark_eager_train(
+        'eager_train_with_defun', MockIterator, defun=True)
 
   def benchmark_eager_train_datasets(self):
 
@@ -226,7 +247,18 @@
         ds = tf.data.Dataset.from_tensors(tensors).repeat()
       return tfe.Iterator(ds)
 
-    self._benchmark_eager_train('eager_train_dataset', make_iterator)
+    self._benchmark_eager_train(
+        'eager_train_dataset', make_iterator, defun=False)
+
+  def benchmark_eager_train_datasets_with_defun(self):
+    """Benchmarks dataset-fed eager training with model.call as a defun."""
+    def make_iterator(tensors):
+      with tf.device('/device:CPU:0'):
+        ds = tf.data.Dataset.from_tensors(tensors).repeat()
+      return tfe.Iterator(ds)
+
+    self._benchmark_eager_train(
+        'eager_train_dataset_with_defun', make_iterator, defun=True)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py
index 3eb4f5f..8e6b947 100644
--- a/tensorflow/contrib/eager/python/network_test.py
+++ b/tensorflow/contrib/eager/python/network_test.py
@@ -105,15 +105,13 @@
     result = net(constant_op.constant([[2.0]]))
     self.assertEqual(34.0, self.evaluate(result))
 
-  # TODO(akshayka): This test should be changed once an API for compiling
-  # `call` into a defun is implemented.
   def testReplacingNetworkCallWithDefun(self):
     net = MyNetwork(name="abcd")
+    net.call = function.defun(net.call)
     x = constant_op.constant([[2.0]])
     net(x)  # Force variables to be created.
     self.evaluate(net.trainable_variables[0].assign([[17.0]]))
 
-    net.call = function.defun(net.call)
     result = net(x)  # Build and execute the TensorFlow function
     self.assertEqual(34.0, self.evaluate(result))
 
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 1150328..a9bdbe0 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -3237,7 +3237,11 @@
       images = random_ops.random_uniform((5, height, width, 3), seed=1)
       regularizer = regularizers.l2_regularizer(0.01)
       layers_lib.separable_conv2d(
-          images, 32, [3, 3], 2, weights_regularizer=regularizer)
+          images,
+          32, [3, 3],
+          2,
+          weights_regularizer=regularizer,
+          weights_initializer=init_ops.ones_initializer())
       self.assertEqual(
           len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 2)
       weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0]
@@ -3245,12 +3249,31 @@
           weight_decay.op.name,
           'SeparableConv2d/depthwise_kernel/Regularizer/l2_regularizer')
       sess.run(variables_lib.global_variables_initializer())
-      self.assertLessEqual(sess.run(weight_decay), 0.05)
+      depth_weight_one = sess.run(weight_decay)
       weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[1]
       self.assertEqual(
           weight_decay.op.name,
           'SeparableConv2d/pointwise_kernel/Regularizer/l2_regularizer')
-      self.assertLessEqual(sess.run(weight_decay), 0.05)
+      pointwise_weight_one = sess.run(weight_decay)
+
+      regularizer = regularizers.l2_regularizer(1.0)
+      layers_lib.separable_conv2d(
+          images,
+          32, [3, 3],
+          2,
+          weights_regularizer=regularizer,
+          weights_initializer=init_ops.ones_initializer())
+      self.assertEqual(
+          len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 4)
+      weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[2]
+      sess.run(variables_lib.global_variables_initializer())
+      depth_weight_two = sess.run(weight_decay)
+      weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[3]
+      pointwise_weight_two = sess.run(weight_decay)
+
+      self.assertAllClose(
+          [100.0 * depth_weight_one, 100.0 * pointwise_weight_one],
+          [depth_weight_two, pointwise_weight_two])
 
   def testReuseConvWithWeightDecay(self):
     height, width = 3, 3
diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
index 0dafb1f..a885a57 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
+++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
@@ -96,19 +96,19 @@
 }
 
 NSString* RunInferenceOnImage() {
-  std::string graph;
+  NSString* graph = @"mobilenet_v1_1.0_224";
   const int num_threads = 1;
   std::string input_layer_type = "float";
   std::vector<int> sizes = {1, 224, 224, 3};
 
-  NSString* graph_path = FilePathForResourceName(@"mobilenet_v1_1.0_224", @"tflite");
+  const NSString* graph_path = FilePathForResourceName(graph, @"tflite");
 
   std::unique_ptr<tflite::FlatBufferModel> model(
       tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]));
   if (!model) {
-    LOG(FATAL) << "Failed to mmap model " << graph;
+    LOG(FATAL) << "Failed to mmap model " << [graph UTF8String];
   }
-  LOG(INFO) << "Loaded model " << graph;
+  LOG(INFO) << "Loaded model " << [graph UTF8String];
   model->error_reporter();
   LOG(INFO) << "resolved reporter";
 
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index 81567ce..81df517 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -247,7 +247,8 @@
 INSTANTIATE_TESTS(space_to_batch_nd)
 INSTANTIATE_TESTS(batch_to_space_nd)
 INSTANTIATE_TESTS(concat)
-INSTANTIATE_TESTS(constant)
+// TODO(b/71642435) re-enable this test
+// INSTANTIATE_TESTS(constant)
 INSTANTIATE_TESTS(control_dep)
 INSTANTIATE_TESTS(conv)
 INSTANTIATE_TESTS(depthwiseconv)
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 51d76e4..90fa442 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -802,8 +802,10 @@
   *reshape_op->add_input() = src_op.inputs[1];
   (*reshape_op->mutable_attr())["T"].set_type(DT_FLOAT);
   const auto& shape_array = model.GetArray(src_op.inputs[1]);
-  CHECK(shape_array.data_type == ArrayDataType::kInt32);
-  CHECK(shape_array.buffer != nullptr);
+  QCHECK(shape_array.data_type == ArrayDataType::kInt32)
+      << "Only int32 shape is supported.";
+  QCHECK(shape_array.buffer != nullptr)
+      << "Shape inferred at runtime is not supported.";
   const auto& shape_data = shape_array.GetBuffer<ArrayDataType::kInt32>().data;
   CreateReshapeShapeTensorConst(src_op.inputs[1], shape_data, tensorflow_graph);
 }
diff --git a/tensorflow/contrib/model_pruning/python/layers/core_layers.py b/tensorflow/contrib/model_pruning/python/layers/core_layers.py
index 95dfd8f..764ab62 100644
--- a/tensorflow/contrib/model_pruning/python/layers/core_layers.py
+++ b/tensorflow/contrib/model_pruning/python/layers/core_layers.py
@@ -210,7 +210,7 @@
       return self.activation(outputs)
     return outputs
 
-  def _compute_output_shape(self, input_shape):
+  def compute_output_shape(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape).as_list()
     if self.data_format == 'channels_last':
       space = input_shape[1:-1]
@@ -467,7 +467,7 @@
       return self.activation(outputs)  # pylint: disable=not-callable
     return outputs
 
-  def _compute_output_shape(self, input_shape):
+  def compute_output_shape(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
     input_shape = input_shape.with_rank_at_least(2)
     if input_shape[-1].value is None:
diff --git a/tensorflow/contrib/py2tf/pyct/BUILD b/tensorflow/contrib/py2tf/pyct/BUILD
index 4322852..dca380c 100644
--- a/tensorflow/contrib/py2tf/pyct/BUILD
+++ b/tensorflow/contrib/py2tf/pyct/BUILD
@@ -22,6 +22,7 @@
         "compiler.py",
         "parser.py",
         "pretty_printer.py",
+        "templates.py",
     ],
     srcs_version = "PY2AND3",
     visibility = ["//visibility:public"],
@@ -30,6 +31,10 @@
 py_test(
     name = "anno_test",
     srcs = ["anno_test.py"],
+    tags = [
+        "manual",
+        "notap",
+    ],
     deps = [
         ":pyct",
         "//tensorflow/python:client_testlib",
@@ -74,3 +79,16 @@
         "//tensorflow/python:client_testlib",
     ],
 )
+
+py_test(
+    name = "templates_test",
+    srcs = ["templates_test.py"],
+    tags = [
+        "manual",
+        "notap",
+    ],
+    deps = [
+        ":pyct",
+        "//tensorflow/python:client_testlib",
+    ],
+)
diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py
new file mode 100644
index 0000000..6acc03b
--- /dev/null
+++ b/tensorflow/contrib/py2tf/pyct/templates.py
@@ -0,0 +1,112 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AST conversion templates.
+
+Adapted from Tangent.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import parser
+
+
+class ReplaceTransformer(gast.NodeTransformer):
+  """Replace AST nodes."""
+
+  def __init__(self, replacements):
+    """Create a new ReplaceTransformer.
+
+    Args:
+      replacements: A mapping from placeholder names to (lists of) AST nodes
+          that these placeholders will be replaced by.
+    """
+    self.replacements = replacements
+
+  # TODO(mdan): Make a more detailed pass and clean up if needed.
+
+  def visit_Expr(self, node):
+    """Replaces a statement that is only a placeholder name; else recurses."""
+    if (isinstance(node.value, gast.Name) and
+        node.value.id in self.replacements):
+      return self.visit(node.value)
+    return self.generic_visit(node)
+
+  def visit_FunctionDef(self, node):
+    """Renames the function if its name is a placeholder (Name nodes only)."""
+    node = self.generic_visit(node)
+    if node.name in self.replacements:
+      repl = self.replacements[node.name]
+      if not isinstance(repl, (gast.Name, ast.Name)):
+        raise ValueError(
+            'A function name can only be replaced by a Name node. Found: %s' %
+            (repl,))
+      node.name = repl.id
+    return node
+
+  def visit_Name(self, node):
+    """Returns the replacement node(s) for a placeholder, else the node."""
+    # Note: The caller is responsible for making sure the replacement
+    # Name nodes have the proper ctx set up.
+    # TODO(mdan): Is it possible to always infer the proper context here?
+    if node.id not in self.replacements:
+      return node
+    # TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations.
+    new_nodes = self.replacements[node.id]
+    if isinstance(new_nodes, gast.AST) or len(new_nodes) != 1:
+      return new_nodes
+    only_node, = new_nodes
+    return only_node
+
+
+def replace(template, **replacements):
+  """Replace placeholders in a Python template.
+
+  Args:
+    template: A function to be used as a template. Any placeholder is expected
+        to also be a function argument.
+    **replacements: A mapping from placeholder names to (lists of) AST nodes
+        that these placeholders will be replaced by.
+
+  Returns:
+    The list of AST nodes that make up the body of the template function, with
+    the placeholders replaced. The function wrapping the template is stripped.
+
+  Raises:
+    ValueError: If the replacement names do not exactly match the template
+        function's arguments.
+  """
+  # The first parsed statement is expected to be the template function itself.
+  tree = parser.parse_object(template).body[0]
+  placeholders = set(arg.id for arg in tree.args.args)
+  tree.args.args = []
+  if tree.args.vararg:
+    # gast models *args as a Name node, like the positional arguments, so take
+    # its id: the set is compared against the replacement names (strings).
+    placeholders.add(tree.args.vararg.id)
+    tree.args.vararg = None
+  if set(replacements.keys()) != placeholders:
+    raise ValueError(
+        'too many or few replacements. replacements: %s; placeholders: %s' %
+        (replacements.keys(), placeholders))
+
+  # Perform the replacement, stripping the function into which the template was
+  # wrapped.
+  return ReplaceTransformer(replacements).visit(tree).body
diff --git a/tensorflow/contrib/py2tf/pyct/templates_test.py b/tensorflow/contrib/py2tf/pyct/templates_test.py
new file mode 100644
index 0000000..2ad8b93
--- /dev/null
+++ b/tensorflow/contrib/py2tf/pyct/templates_test.py
@@ -0,0 +1,77 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for templates module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.python.platform import test
+
+
+class TemplatesTest(test.TestCase):
+  """Tests for templates.replace."""
+
+  def test_replace_variable(self):
+    """A placeholder variable is replaced by the given Name node."""
+    def template(a):  # pylint:disable=unused-argument
+      def test_fn(a):  # pylint:disable=unused-variable
+        a += 1
+        a = 2 * a + 1
+        return b  # pylint:disable=undefined-variable
+
+    node = templates.replace(
+        template, a=gast.Name('b', gast.Load(), None))[0]
+    result = compiler.ast_to_object(node)
+    self.assertEqual(7, result.test_fn(2))
+
+  def test_replace_function_name(self):
+    """A placeholder function name is replaced by the given Name's id."""
+    def template(fname):  # pylint:disable=unused-argument
+      def fname(a):  # pylint:disable=function-redefined
+        a += 1
+        a = 2 * a + 1
+        return a
+
+    node = templates.replace(
+        template, fname=gast.Name('test_fn', gast.Load(), None))[0]
+    result = compiler.ast_to_object(node)
+    self.assertEqual(7, result.test_fn(2))
+
+  def test_code_block(self):
+    """A placeholder statement is replaced by a list of statements."""
+    def template(block):  # pylint:disable=unused-argument
+      def test_fn(a):  # pylint:disable=unused-variable
+        block  # pylint:disable=pointless-statement
+        return a
+
+    node = templates.replace(
+        template,
+        block=[
+            gast.Assign(
+                [gast.Name('a', gast.Store(), None)],
+                gast.BinOp(gast.Name('a', gast.Load(), None), gast.Add(),
+                           gast.Num(1))),
+        ] * 2)[0]
+    result = compiler.ast_to_object(node)
+    self.assertEqual(3, result.test_fn(1))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 46823fa..7378920 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -845,12 +845,14 @@
       batch_size = 3
       input_size = 4
       expected_state_c = np.array(
-          [[0.00072015, 0.00036633], [0.00083481, 0.00047266],
-           [0.00085111, 0.00053054]],
+          [[6.450831e-04, 4.697885e-04],
+           [9.862894e-05, 7.212213e-04],
+           [4.401947e-04, 9.143004e-04]],
           dtype=np.float32)
       expected_state_h = np.array(
-          [[0.0005159, 0.00026243], [0.00062958, 0.00035646],
-           [0.00064732, 0.00040351]],
+          [[4.621217e-04, 3.365449e-04],
+           [7.438179e-05, 5.439147e-04],
+           [3.347936e-04, 6.953785e-04]],
           dtype=np.float32)
       with variable_scope.variable_scope(
           "root", initializer=init_ops.constant_initializer(0.5)):
@@ -1328,7 +1330,7 @@
     h_low = 0.761552567265
     h_high = 0.995008519604
     num_units = 5
-    allowed_low = [2, 3]
+    allowed_low = [1, 2, 3]
 
     with self.test_session() as sess:
       with variable_scope.variable_scope(
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index e5d5917..7465f20 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -69,7 +69,7 @@
   def assertAllCloseOrEqual(self, x, y, **kwargs):
     if isinstance(x, np.ndarray) or isinstance(x, float):
       return super(AttentionWrapperTest, self).assertAllClose(
-          x, y, atol=1e-4, **kwargs)
+          x, y, atol=1e-3, **kwargs)
     else:
       self.assertAllEqual(x, y, **kwargs)
 
@@ -276,7 +276,7 @@
         rnn_output=ResultSummary(
             shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.00597103),
         sample_id=ResultSummary(
-            shape=(5, 3), dtype=dtype('int32'), mean=1.4))
+            shape=(5, 3), dtype=dtype('int32'), mean=1.6))
     expected_final_state = AttentionWrapperState(
         cell_state=LSTMStateTuple(
             c=ResultSummary(
@@ -305,7 +305,7 @@
         rnn_output=ResultSummary(
             shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0052615386),
         sample_id=ResultSummary(
-            shape=(5, 3), dtype=dtype('int32'), mean=1.4666666666666666))
+            shape=(5, 3), dtype=dtype('int32'), mean=1.3333333333))
     expected_final_state = AttentionWrapperState(
         cell_state=LSTMStateTuple(
             c=ResultSummary(
@@ -336,7 +336,7 @@
         rnn_output=ResultSummary(
             shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0052615386),
         sample_id=ResultSummary(
-            shape=(5, 3), dtype=dtype('int32'), mean=1.4666666666666666))
+            shape=(5, 3), dtype=dtype('int32'), mean=1.3333333333333333))
     expected_final_state = AttentionWrapperState(
         cell_state=LSTMStateTuple(
             c=ResultSummary(
@@ -578,7 +578,7 @@
         rnn_output=ResultSummary(
             shape=(5, 3, 6), dtype=dtype('float32'), mean=-0.0025896581),
         sample_id=ResultSummary(
-            shape=(5, 3), dtype=dtype('int32'), mean=1.8666666666666667))
+            shape=(5, 3), dtype=dtype('int32'), mean=1.6))
     expected_final_state = AttentionWrapperState(
         cell_state=LSTMStateTuple(
             c=ResultSummary(
@@ -594,7 +594,7 @@
             shape=(5, 8), dtype=dtype('float32'), mean=0.028698336),
         alignment_history=())
     expected_final_alignment_history = ResultSummary(
-        shape=(3, 5, 8), dtype=dtype('float32'), mean=0.046009291)
+        shape=(3, 5, 8), dtype=dtype('float32'), mean=0.04865776002407074)
 
     self._testWithAttention(
         create_attention_mechanism,
@@ -761,9 +761,9 @@
 
     expected_final_output = BasicDecoderOutput(
         rnn_output=ResultSummary(
-            shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11691988),
+            shape=(5, 3, 20), dtype=dtype('float32'), mean=0.11798714846372604),
         sample_id=ResultSummary(
-            shape=(5, 3), dtype=dtype('int32'), mean=7.2666666666666666))
+            shape=(5, 3), dtype=dtype('int32'), mean=7.933333333333334))
     expected_final_state = AttentionWrapperState(
         cell_state=LSTMStateTuple(
             c=ResultSummary(
@@ -771,7 +771,7 @@
             h=ResultSummary(
                 shape=(5, 9), dtype=dtype('float32'), mean=-0.0018835809)),
         attention=ResultSummary(
-            shape=(5, 20), dtype=dtype('float32'), mean=0.11680689),
+            shape=(5, 20), dtype=dtype('float32'), mean=0.11798714846372604),
         time=3,
         alignments=(
             ResultSummary(shape=(5, 8), dtype=dtype('float32'), mean=0.125),
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index cddd628..a998ac1 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -237,8 +237,7 @@
     if params.inference_tree_paths:
       model_ops.predictions[TREE_PATHS_PREDICTION_KEY] = tree_paths
 
-    if params.regression:
-      model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
+    model_ops.predictions[VARIANCE_PREDICTION_KEY] = regression_variance
 
     return model_ops
 
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index eb93876..3650b5d 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -478,8 +478,7 @@
       **inference_args: Keyword arguments to pass through to each tree.
 
     Returns:
-      A tuple of (probabilities, tree_paths, variance), where variance
-      is the variance over all the trees for regression problems only.
+      A tuple of (probabilities, tree_paths, variance).
 
     Raises:
       NotImplementedError: If trying to use feature bagging with sparse
@@ -513,13 +512,12 @@
           self.params.num_trees,
           name='probabilities')
       tree_paths = array_ops.stack(paths, axis=1)
-      regression_variance = None
-      if self.params.regression:
-        expected_squares = math_ops.div(
-            math_ops.reduce_sum(all_predict * all_predict, 1),
-            self.params.num_trees)
-        regression_variance = math_ops.maximum(
-            0., expected_squares - average_values * average_values)
+
+      expected_squares = math_ops.div(
+          math_ops.reduce_sum(all_predict * all_predict, 1),
+          self.params.num_trees)
+      regression_variance = math_ops.maximum(
+          0., expected_squares - average_values * average_values)
       return average_values, tree_paths, regression_variance
 
   def average_size(self):
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
index 113dfb8..bbe627b 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py
@@ -108,7 +108,7 @@
     probs, paths, var = graph_builder.inference_graph(input_data)
     self.assertTrue(isinstance(probs, ops.Tensor))
     self.assertTrue(isinstance(paths, ops.Tensor))
-    self.assertIsNone(var)
+    self.assertTrue(isinstance(var, ops.Tensor))
 
   def testTrainingConstructionClassificationSparse(self):
     input_data = sparse_tensor.SparseTensor(
diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc
index 1aff789..0514fce 100644
--- a/tensorflow/contrib/tensorboard/db/schema.cc
+++ b/tensorflow/contrib/tensorboard/db/schema.cc
@@ -19,7 +19,8 @@
 
 Status Run(Sqlite* db, const char* sql) {
   auto stmt = db->Prepare(sql);
-  TF_RETURN_WITH_CONTEXT_IF_ERROR(stmt.StepAndReset(), sql);
+  TF_RETURN_IF_ERROR(stmt.status());
+  TF_RETURN_IF_ERROR(stmt.ValueOrDie().StepAndReset());
   return Status::OK();
 }
 
@@ -28,11 +29,11 @@
 Status SetupTensorboardSqliteDb(Sqlite* db) {
   // Note: GCC raw strings macros are broken.
   // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971
-  db->Prepare(strings::StrCat("PRAGMA application_id=",
-                              kTensorboardSqliteApplicationId))
-      .StepAndReset()
-      .IgnoreError();
-  db->Prepare("PRAGMA user_version=0").StepAndReset().IgnoreError();
+  TF_RETURN_IF_ERROR(
+      db->PrepareOrDie(strings::StrCat("PRAGMA application_id=",
+                                       kTensorboardSqliteApplicationId))
+          .StepAndReset());
+  db->PrepareOrDie("PRAGMA user_version=0").StepAndResetOrDie();
   Status s;
 
   // Creates Ids table.
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
index fc201be..a9d75e9 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
@@ -133,7 +133,8 @@
 class IdAllocator {
  public:
   IdAllocator(Env* env, Sqlite* db)
-      : env_{env}, inserter_{db->Prepare("INSERT INTO Ids (id) VALUES (?)")} {}
+      : env_{env},
+        inserter_{db->PrepareOrDie("INSERT INTO Ids (id) VALUES (?)")} {}
 
   Status CreateNewId(int64* id) {
     Status s;
@@ -208,7 +209,7 @@
   }
 
   Status SaveNodeInputs() {
-    auto insert = db_->Prepare(R"sql(
+    auto insert = db_->PrepareOrDie(R"sql(
       INSERT INTO NodeInputs (graph_id, node_id, idx, input_node_id, is_control)
       VALUES (?, ?, ?, ?, ?)
     )sql");
@@ -236,7 +237,7 @@
   }
 
   Status SaveNodes() {
-    auto insert = db_->Prepare(R"sql(
+    auto insert = db_->PrepareOrDie(R"sql(
       INSERT INTO Nodes (graph_id, node_id, node_name, op, device, node_def)
       VALUES (?, ?, ?, ?, ?, snap(?))
     )sql");
@@ -262,7 +263,7 @@
   }
 
   Status SaveGraph() {
-    auto insert = db_->Prepare(R"sql(
+    auto insert = db_->PrepareOrDie(R"sql(
       INSERT INTO Graphs (graph_id, inserted_time, graph_def)
       VALUES (?, ?, snap(?))
     )sql");
@@ -291,14 +292,14 @@
         experiment_name_{experiment_name},
         run_name_{run_name},
         user_name_{user_name},
-        insert_tensor_{db_->Prepare(R"sql(
+        insert_tensor_{db_->PrepareOrDie(R"sql(
           INSERT OR REPLACE INTO Tensors (tag_id, step, computed_time, tensor)
           VALUES (?, ?, ?, snap(?))
         )sql")} {}
 
   ~RunWriter() {
     if (run_id_ == kAbsent) return;
-    auto update = db_->Prepare(R"sql(
+    auto update = db_->PrepareOrDie(R"sql(
       UPDATE Runs SET finished_time = ? WHERE run_id = ?
     )sql");
     update.BindDouble(1, GetWallTime(env_));
@@ -331,7 +332,8 @@
     TF_RETURN_IF_ERROR(
         GraphSaver::Save(env_, db_.get(), &id_allocator_, g.get(), &graph_id));
     if (run_id_ != kAbsent) {
-      auto set = db_->Prepare("UPDATE Runs SET graph_id = ? WHERE run_id = ?");
+      auto set =
+          db_->PrepareOrDie("UPDATE Runs SET graph_id = ? WHERE run_id = ?");
       set.BindInt(1, graph_id);
       set.BindInt(2, run_id_);
       TF_RETURN_IF_ERROR(set.StepAndReset());
@@ -350,14 +352,14 @@
     TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(tag_id));
     tag_ids_[tag_name] = *tag_id;
     if (!metadata.summary_description().empty()) {
-      SqliteStatement insert_description = db_->Prepare(R"sql(
+      SqliteStatement insert_description = db_->PrepareOrDie(R"sql(
         INSERT INTO Descriptions (id, description) VALUES (?, ?)
       )sql");
       insert_description.BindInt(1, *tag_id);
       insert_description.BindText(2, metadata.summary_description());
       TF_RETURN_IF_ERROR(insert_description.StepAndReset());
     }
-    SqliteStatement insert = db_->Prepare(R"sql(
+    SqliteStatement insert = db_->PrepareOrDie(R"sql(
       INSERT INTO Tags (
         run_id,
         tag_id,
@@ -387,7 +389,7 @@
  private:
   Status InitializeUser() {
     if (user_id_ != kAbsent || user_name_.empty()) return Status::OK();
-    SqliteStatement get = db_->Prepare(R"sql(
+    SqliteStatement get = db_->PrepareOrDie(R"sql(
       SELECT user_id FROM Users WHERE user_name = ?
     )sql");
     get.BindText(1, user_name_);
@@ -398,7 +400,7 @@
       return Status::OK();
     }
     TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&user_id_));
-    SqliteStatement insert = db_->Prepare(R"sql(
+    SqliteStatement insert = db_->PrepareOrDie(R"sql(
       INSERT INTO Users (user_id, user_name, inserted_time) VALUES (?, ?, ?)
     )sql");
     insert.BindInt(1, user_id_);
@@ -412,7 +414,7 @@
     if (experiment_name_.empty()) return Status::OK();
     if (experiment_id_ == kAbsent) {
       TF_RETURN_IF_ERROR(InitializeUser());
-      SqliteStatement get = db_->Prepare(R"sql(
+      SqliteStatement get = db_->PrepareOrDie(R"sql(
         SELECT
           experiment_id,
           started_time
@@ -432,7 +434,7 @@
       } else {
         TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&experiment_id_));
         experiment_started_time_ = computed_time;
-        SqliteStatement insert = db_->Prepare(R"sql(
+        SqliteStatement insert = db_->PrepareOrDie(R"sql(
           INSERT INTO Experiments (
             user_id,
             experiment_id,
@@ -451,7 +453,7 @@
     }
     if (computed_time < experiment_started_time_) {
       experiment_started_time_ = computed_time;
-      SqliteStatement update = db_->Prepare(R"sql(
+      SqliteStatement update = db_->PrepareOrDie(R"sql(
         UPDATE Experiments SET started_time = ? WHERE experiment_id = ?
       )sql");
       update.BindDouble(1, computed_time);
@@ -467,7 +469,7 @@
     if (run_id_ == kAbsent) {
       TF_RETURN_IF_ERROR(id_allocator_.CreateNewId(&run_id_));
       run_started_time_ = computed_time;
-      SqliteStatement insert = db_->Prepare(R"sql(
+      SqliteStatement insert = db_->PrepareOrDie(R"sql(
         INSERT OR REPLACE INTO Runs (
           experiment_id,
           run_id,
@@ -485,7 +487,7 @@
     }
     if (computed_time < run_started_time_) {
       run_started_time_ = computed_time;
-      SqliteStatement update = db_->Prepare(R"sql(
+      SqliteStatement update = db_->PrepareOrDie(R"sql(
         UPDATE Runs SET started_time = ? WHERE run_id = ?
       )sql");
       update.BindDouble(1, computed_time);
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
index 5ea844b..cfc6192 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
@@ -48,7 +48,7 @@
 
 class SummaryDbWriterTest : public ::testing::Test {
  protected:
-  void SetUp() override { db_ = Sqlite::Open(":memory:").ValueOrDie(); }
+  void SetUp() override { db_ = Sqlite::OpenOrDie(":memory:"); }
 
   void TearDown() override {
     if (writer_ != nullptr) {
@@ -58,7 +58,7 @@
   }
 
   int64 QueryInt(const string& sql) {
-    SqliteStatement stmt = db_->Prepare(sql);
+    SqliteStatement stmt = db_->PrepareOrDie(sql);
     bool is_done;
     Status s = stmt.Step(&is_done);
     if (!s.ok() || is_done) {
@@ -69,7 +69,7 @@
   }
 
   double QueryDouble(const string& sql) {
-    SqliteStatement stmt = db_->Prepare(sql);
+    SqliteStatement stmt = db_->PrepareOrDie(sql);
     bool is_done;
     Status s = stmt.Step(&is_done);
     if (!s.ok() || is_done) {
@@ -80,7 +80,7 @@
   }
 
   string QueryString(const string& sql) {
-    SqliteStatement stmt = db_->Prepare(sql);
+    SqliteStatement stmt = db_->PrepareOrDie(sql);
     bool is_done;
     Status s = stmt.Step(&is_done);
     if (!s.ok() || is_done) {
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index ff100ca..848e5a6 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -176,7 +176,9 @@
     additional_deps = [
         ":tpu",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python:dtypes",
         "//tensorflow/python:framework",
+        "//tensorflow/python:layers",
     ],
 )
 
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 79cda18..8fec379 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -142,8 +142,9 @@
   def _AddOpInternal(self, op):
     # pylint: disable=protected-access
     if op.type in _BLACKLISTED_OPS:
-      raise ValueError("Operation of type %s (%s) is not supported on the TPU" %
-                       (op.type, op.name))
+      logging.error("Operation of type %s (%s) is not supported on the TPU. "
+                    "Execution will fail if this op is used in the graph. " %
+                    (op.type, op.name))
 
     if op.type in _NOT_IMPLEMENTED_OPS:
       self._unsupported_ops.append(op)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py
index 2de5419..336d826 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_test.py
@@ -20,8 +20,14 @@
 from __future__ import print_function
 
 from tensorflow.contrib.tpu.python.tpu import tpu
+from tensorflow.contrib.tpu.python.tpu import tpu_feed
+from tensorflow.contrib.tpu.python.tpu import training_loop
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.layers import convolutional
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import math_ops
 
 from tensorflow.python.platform import test
 
@@ -39,5 +45,36 @@
     self.assertTrue(control_flow_util.IsInXLAContext(z2.op))
 
 
+class TPULayerRewriteTest(test.TestCase):
+  """Checks that tpu.rewrite accepts layers built inside a training loop."""
+
+  def testUsingInfeedQueueWithRegularizer(self):
+    """Test that Layer regularizers can reference data created in loops."""
+
+    def make_regularizer(scale):
+      # The closure captures `scale`, which is a training_step argument.
+      def regularizer(weights):
+        return scale * math_ops.reduce_sum(math_ops.square(weights))
+      return regularizer
+
+    def training_step(inputs, scale):
+      conv_out = convolutional.conv2d(
+          inputs, filters=16, kernel_size=(3, 3), data_format="channels_first",
+          kernel_regularizer=make_regularizer(scale))
+      return math_ops.reduce_mean(math_ops.square(conv_out)).op
+
+    inputs = array_ops.zeros(shape=(128, 32, 32, 16))
+    scale = array_ops.ones(shape=())
+    infeed = tpu_feed.InfeedQueue(
+        tuple_types=[dtypes.float32, dtypes.float32],
+        tuple_shapes=[inputs.shape, scale.shape])
+
+    def loop():
+      return training_loop.repeat(5, training_step, infeed_queue=infeed)
+
+    # This should not throw an error.
+    tpu.rewrite(loop)
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index d032cf9..ef8ca3f 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -363,6 +363,23 @@
     ],
 )
 
+cc_library(
+    name = "abi",
+    srcs = ["platform/abi.cc"],
+    hdrs = ["platform/abi.h"],
+)
+
+cc_library(
+    name = "stacktrace_handler",
+    srcs = ["platform/stacktrace_handler.cc"],
+    hdrs = ["platform/stacktrace_handler.h"],
+    deps = [
+        ":abi",
+        ":lib",
+        ":lib_platform",
+    ],
+)
+
 # Test support library needed for all tests
 # This is currently public, but may be made internal in the
 # future.  Try to avoid depending on it.
@@ -2359,6 +2376,7 @@
     deps = [
         ":lib",
         ":lib_internal",
+        ":stacktrace_handler",
         ":test",  # buildcleaner: keep
         "//tensorflow/core/platform/default/build_config:test_main",
     ],
@@ -2429,6 +2447,7 @@
         "platform/net_test.cc",
         "platform/port_test.cc",
         "platform/profile_utils/cpu_utils_test.cc",
+        "platform/stacktrace_handler_test.cc",
         "platform/subprocess_test.cc",
     ],
     deps = [
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 286266a..51d7f98 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -152,7 +152,6 @@
   ~FunctionLibraryRuntimeImpl() override;
 
   Status Instantiate(const string& function_name, AttrSlice attrs,
-                     const InstantiateOptions& options,
                      Handle* handle) override;
 
   Status ReleaseHandle(Handle handle) override;
@@ -224,7 +223,7 @@
   Status GetOrCreateItem(Handle handle, Item** item);
   Status InstantiateSymbolicGradient(const NameAttrList& func,
                                      FunctionBody** g_body);
-  bool IsLocalTarget(const InstantiateOptions& options);
+  bool IsLocalTarget(const AttrSlice& attrs);
   AttrValueMap FixAttrs(const AttrSlice& attrs);
   void RunRemote(const Options& opts, Handle handle,
                  gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
@@ -353,8 +352,7 @@
   // Try to instantiate this function for the func/attr. Maybe it's
   // cached already.
   Handle handle;
-  TF_RETURN_IF_ERROR(
-      Instantiate(ndef.op(), AttrSlice(&ndef.attr()), {}, &handle));
+  TF_RETURN_IF_ERROR(Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle));
 
   const FunctionBody* fbody = GetFunctionBody(handle);
   CHECK_NOTNULL(fbody);
@@ -413,7 +411,7 @@
     // f is a user-defined function.
     Handle f_handle;
     TF_RETURN_IF_ERROR(
-        Instantiate(func.name(), AttrSlice(&func.attr()), {}, &f_handle));
+        Instantiate(func.name(), AttrSlice(&func.attr()), &f_handle));
     const FunctionBody* f_body = GetFunctionBody(f_handle);
     CHECK_NOTNULL(f_body);
     *g_body = SymbolicGradient(*f_body);
@@ -421,25 +419,42 @@
   return Status::OK();
 }
 
-bool FunctionLibraryRuntimeImpl::IsLocalTarget(
-    const InstantiateOptions& options) {
+bool FunctionLibraryRuntimeImpl::IsLocalTarget(const AttrSlice& attrs) {
   if (device_ == nullptr) return true;
-  if (options.target.empty()) return true;
+  string target = ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
+  if (target.empty()) return true;
   Device* target_device;
-  if (!device_mgr_->LookupDevice(options.target, &target_device).ok()) {
+  if (!device_mgr_->LookupDevice(target, &target_device).ok()) {
     return false;
   }
   return target_device == device_;
 }
 
-Status FunctionLibraryRuntimeImpl::Instantiate(
-    const string& function_name, AttrSlice attrs,
-    const InstantiateOptions& options, Handle* handle) {
-  if (!IsLocalTarget(options)) {
-    return parent_->Instantiate(function_name, attrs, options, handle);
+// Copies `attrs` into a map, adding a "_target" attr that names this runtime's
+// device (device_name_) when the caller did not supply one.
+AttrValueMap FunctionLibraryRuntimeImpl::FixAttrs(const AttrSlice& attrs) {
+  AttrValueMap value_map;
+  // Bind by const reference so each AttrValue is copied once, not twice.
+  for (const auto& entry : attrs) value_map[entry.first] = entry.second;
+  if (attrs.Find("_target") == nullptr) {
+    AttrValue v;
+    v.set_s(device_name_);
+    AddAttr("_target", v, &value_map);
+  }
+  return value_map;
+}
+
+Status FunctionLibraryRuntimeImpl::Instantiate(const string& function_name,
+                                               AttrSlice attrs,
+                                               Handle* handle) {
+  AttrValueMap value_map = FixAttrs(attrs);
+  AttrSlice new_attrs(&value_map);
+
+  if (!IsLocalTarget(new_attrs)) {
+    return parent_->Instantiate(function_name, new_attrs, handle);
   }
 
-  const string key = Canonicalize(function_name, attrs, options);
+  const string key = Canonicalize(function_name, new_attrs);
   *handle = parent_->GetHandle(key);
   if (*handle != kInvalidHandle) {
     return Status::OK();
@@ -448,7 +463,7 @@
   Status s;
   FunctionBody* fbody = nullptr;
   if (function_name == kGradientOp) {
-    const AttrValue* f = attrs.Find(kFuncAttr);
+    const AttrValue* f = new_attrs.Find(kFuncAttr);
     if (f == nullptr) {
       return errors::InvalidArgument("SymbolicGradient is missing attr: f");
     }
@@ -458,7 +473,7 @@
     }
     const string grad = lib_def_->FindGradient(func.name());
     if (!grad.empty()) {
-      return Instantiate(grad, AttrSlice(&func.attr()), options, handle);
+      return Instantiate(grad, AttrSlice(&func.attr()), handle);
     }
     TF_RETURN_IF_ERROR(InstantiateSymbolicGradient(func, &fbody));
   } else {
@@ -466,7 +481,7 @@
     if (fdef == nullptr) {
       return errors::NotFound("Function ", function_name, " is not defined.");
     }
-    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, attrs, &fbody));
+    TF_RETURN_IF_ERROR(FunctionDefToBody(*fdef, new_attrs, &fbody));
   }
 
   {
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 2dacace..d4181ff 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -191,14 +191,11 @@
   Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
                      test::function::Attrs attrs,
                      FunctionLibraryRuntime::Handle* handle) {
-    return flr->Instantiate(name, attrs, handle);
-  }
-
-  Status Instantiate(FunctionLibraryRuntime* flr, const string& name,
-                     test::function::Attrs attrs,
-                     const FunctionLibraryRuntime::InstantiateOptions& options,
-                     FunctionLibraryRuntime::Handle* handle) {
-    return flr->Instantiate(name, attrs, options, handle);
+    // Forwards to the runtime under test and returns its Status unchanged.
+    //
+    // A target device, when one is needed, travels inside `attrs` as the
+    // "_target" attribute (see the CrossDevice test).
+    return flr->Instantiate(name, attrs, handle);
   }
 
   Status InstantiateAndRun(FunctionLibraryRuntime* flr, const string& name,
@@ -1091,7 +1088,8 @@
 TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
   Init({test::function::FindDevice()});
   FunctionLibraryRuntime::Handle handle;
-  TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {}, {"/device:CPU:1"}, &handle));
+  TF_CHECK_OK(Instantiate(flr0_, "FindDevice", {{"_target", "/device:CPU:1"}},
+                          &handle));
 
   Tensor y;
   FunctionLibraryRuntime::Options opts;
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 12947e2..53a1412 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -88,6 +88,16 @@
           std::move(custom_kernel_creator), nullptr /* cluster_flr */) {}
 
 /* static */
+string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
+    const AttrSlice& attrs) {
+  // A missing "_target" attr yields "", i.e. no explicit target device.
+  const AttrValue* target_attr = attrs.Find("_target");
+  if (target_attr == nullptr) return "";
+  // Canonicalize the device name before handing it back.
+  return DeviceNameUtils::CanonicalizeDeviceName(target_attr->s());
+}
+
+/* static */
 Status ProcessFunctionLibraryRuntime::SendTensors(
     const string& source_device, const string& target_device,
     const string& key_prefix, int64 src_incarnation,
@@ -230,23 +240,22 @@
 
 Status ProcessFunctionLibraryRuntime::Instantiate(
     const string& function_name, AttrSlice attrs,
-    const FunctionLibraryRuntime::InstantiateOptions& options,
     FunctionLibraryRuntime::Handle* handle) {
   *handle = kInvalidHandle;
-  FunctionLibraryRuntime* flr = GetFLR(options.target);
+  string target = ObtainFunctionTarget(attrs);
+  FunctionLibraryRuntime* flr = GetFLR(target);
   if (flr != nullptr) {
-    return flr->Instantiate(function_name, attrs, options, handle);
+    return flr->Instantiate(function_name, attrs, handle);
   }
   if (parent_ == nullptr) {
     return errors::Internal(
-        "Currently don't support instantiating functions on device: ",
-        options.target);
+        "Currently don't support instantiating functions on device: ", target);
   }
   FunctionLibraryRuntime::Handle cluster_handle;
-  TF_RETURN_IF_ERROR(parent_->Instantiate(function_name, *lib_def_, attrs,
-                                          options, &cluster_handle));
+  TF_RETURN_IF_ERROR(
+      parent_->Instantiate(function_name, *lib_def_, attrs, &cluster_handle));
   string function_key = Canonicalize(function_name, attrs);
-  *handle = AddHandle(function_key, options.target, cluster_handle);
+  *handle = AddHandle(function_key, target, cluster_handle);
   return Status::OK();
 }
 
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 38003b7..3aa7b87 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -53,6 +53,11 @@
                                 const OptimizerOptions& optimizer_options,
                                 CustomKernelCreator custom_kernel_creator);
 
+  // Given a list of attrs on a function, extracts the "_target" attribute which
+  // indicates which device to run the function on. If it can't find the _target
+  // attribute, returns "". Canonicalizes the device name.
+  static string ObtainFunctionTarget(const AttrSlice& attrs);
+
   // Sends `tensors_to_send` from `source_device` to `target_device` using
   // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
   // Rendezvous. `device_context` should be the DeviceContext of the device
@@ -116,7 +121,6 @@
   // Allows for function_name to be instantiated on different devices
   // as specified in attrs.
   Status Instantiate(const string& function_name, AttrSlice attrs,
-                     const FunctionLibraryRuntime::InstantiateOptions& options,
                      FunctionLibraryRuntime::Handle* handle);
 
   // Delegates to the local FLR that owns state corresponding to `handle` and
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index f11b7a8..270e46d 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -49,12 +49,10 @@
   }
 
   Status Run(const string& name, FunctionLibraryRuntime::Options opts,
-             test::function::Attrs attrs,
-             const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts,
-             const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+             test::function::Attrs attrs, const std::vector<Tensor>& args,
+             std::vector<Tensor*> rets) {
     FunctionLibraryRuntime::Handle handle;
-    Status status =
-        proc_flr_->Instantiate(name, attrs, instantiate_opts, &handle);
+    Status status = proc_flr_->Instantiate(name, attrs, &handle);
     if (!status.ok()) {
       return status;
     }
@@ -144,6 +142,21 @@
   rendezvous_->Unref();
 }
 
+TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
+  // Without a "_target" attr the helper reports an empty target.
+  AttrSlice empty_attrs;
+  EXPECT_EQ("",
+            ProcessFunctionLibraryRuntime::ObtainFunctionTarget(empty_attrs));
+  // A "cpu:1" name comes back in the canonical "device:CPU:1" form.
+  AttrValue target_value;
+  target_value.set_s("/job:a/replica:0/task:0/cpu:1");
+  AttrValueMap attr_values;
+  AddAttr("_target", target_value, &attr_values);
+  EXPECT_EQ("/job:a/replica:0/task:0/device:CPU:1",
+            ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
+                AttrSlice(&attr_values)));
+}
+
 TEST_F(ProcessFunctionLibraryRuntimeTest, GetDeviceIncarnation) {
   Init({});
   int64 incarnation;
@@ -165,8 +178,10 @@
   opts.remote_execution = true;
   auto x = test::AsTensor<float>({1, 2, 3, 4});
   Tensor y;
-  TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}},
-                  {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y}));
+  TF_CHECK_OK(
+      Run("XTimesTwo", opts,
+          {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
+          {&y}));
   test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
   rendezvous_->Unref();
 }
@@ -178,8 +193,8 @@
   opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   Tensor y;
-  TF_CHECK_OK(
-      Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:0"}, {}, {&y}));
+  TF_CHECK_OK(Run("FindDevice", opts,
+                  {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
   test::ExpectTensorEqual<string>(
       y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"},
                                 TensorShape({})));
@@ -194,11 +209,15 @@
   opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   Tensor y;
-  TF_CHECK_OK(Run("XTimesTwo", opts, {{"T", DT_FLOAT}},
-                  {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y}));
+  TF_CHECK_OK(
+      Run("XTimesTwo", opts,
+          {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
+          {&y}));
   test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
-  TF_CHECK_OK(Run("XTimesFour", opts, {{"T", DT_FLOAT}},
-                  {"/job:a/replica:0/task:0/cpu:0"}, {x}, {&y}));
+  TF_CHECK_OK(
+      Run("XTimesFour", opts,
+          {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
+          {&y}));
   test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
   rendezvous_->Unref();
 }
@@ -210,13 +229,13 @@
   opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   Tensor y;
-  TF_CHECK_OK(
-      Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:1"}, {}, {&y}));
+  TF_CHECK_OK(Run("FindDevice", opts,
+                  {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
   test::ExpectTensorEqual<string>(
       y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"},
                                 TensorShape({})));
-  TF_CHECK_OK(
-      Run("FindDevice", opts, {}, {"/job:a/replica:0/task:0/cpu:1"}, {}, {&y}));
+  TF_CHECK_OK(Run("FindDevice", opts,
+                  {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
   test::ExpectTensorEqual<string>(
       y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"},
                                 TensorShape({})));
@@ -230,13 +249,11 @@
   opts.rendezvous = rendezvous_;
   opts.remote_execution = true;
   Tensor y;
-  TF_CHECK_OK(Run("FindDevice", opts, {},
-                  {"/job:a/replica:0/task:0/device:CPU:0"}, {}, {&y}));
+  TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:0"}}, {}, {&y}));
   test::ExpectTensorEqual<string>(
       y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:0"},
                                 TensorShape({})));
-  TF_CHECK_OK(Run("FindDevice", opts, {},
-                  {"/job:a/replica:0/task:0/device:CPU:1"}, {}, {&y}));
+  TF_CHECK_OK(Run("FindDevice", opts, {{"_target", "/cpu:1"}}, {}, {&y}));
   test::ExpectTensorEqual<string>(
       y, test::AsTensor<string>({"/job:a/replica:0/task:0/device:CPU:1"},
                                 TensorShape({})));
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
index 3a8d591..d84b69d 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc
@@ -26,10 +26,10 @@
 
 /* static */
 Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
-    const OpDef& sig, AttrSlice attrs,
-    const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
+    const OpDef& sig, AttrSlice attrs, GraphDef* g,
     std::vector<string>* send_keys, std::vector<string>* recv_keys) {
-  const string& target = options.target;
+  const string& target =
+      ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
   // Construct recv nodes for each input argument.
   int i = 0;
   for (const auto& in : sig.input_arg()) {
@@ -119,16 +119,16 @@
 
 Status ClusterFunctionLibraryRuntime::Instantiate(
     const string& function_name, const FunctionLibraryDefinition& lib_def,
-    AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
-    FunctionLibraryRuntime::LocalHandle* handle) {
-  WorkerInterface* wi =
-      worker_session_->worker_cache->CreateWorker(options.target);
+    AttrSlice attrs, FunctionLibraryRuntime::LocalHandle* handle) {
+  const string& target =
+      ProcessFunctionLibraryRuntime::ObtainFunctionTarget(attrs);
+  WorkerInterface* wi = worker_session_->worker_cache->CreateWorker(target);
 
   if (wi == nullptr) {
     std::vector<string> workers;
     worker_session_->worker_cache->ListWorkers(&workers);
     return errors::InvalidArgument(
-        "Could not find worker with target: ", options.target,
+        "Could not find worker with target: ", target,
         " Available workers: ", str_util::Join(workers, ", "));
   }
 
@@ -137,8 +137,8 @@
   const OpDef& sig = fdef->signature();
   GraphDef gdef;
   std::vector<string> send_keys, recv_keys;
-  TF_RETURN_IF_ERROR(ConstructFunctionGraph(sig, attrs, options, &gdef,
-                                            &send_keys, &recv_keys));
+  TF_RETURN_IF_ERROR(
+      ConstructFunctionGraph(sig, attrs, &gdef, &send_keys, &recv_keys));
   *gdef.mutable_library() = lib_def.ToProto();
 
   RegisterGraphRequest req;
@@ -152,8 +152,8 @@
 
   mutex_lock l(mu_);
   *handle = function_data_.size();
-  function_data_.push_back(FunctionData(resp.graph_handle(), options.target, wi,
-                                        send_keys, recv_keys));
+  function_data_.push_back(
+      FunctionData(resp.graph_handle(), target, wi, send_keys, recv_keys));
   return Status::OK();
 }
 
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
index 3deb80d..dd4ea68 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h
@@ -34,7 +34,6 @@
 
   Status Instantiate(const string& function_name,
                      const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
-                     const FunctionLibraryRuntime::InstantiateOptions& options,
                      FunctionLibraryRuntime::LocalHandle* handle) override;
 
   void Run(const FunctionLibraryRuntime::Options& opts,
@@ -43,10 +42,10 @@
            FunctionLibraryRuntime::DoneCallback done) override;
 
  private:
-  static Status ConstructFunctionGraph(
-      const OpDef& sig, AttrSlice attrs,
-      const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
-      std::vector<string>* send_keys, std::vector<string>* recv_keys);
+  static Status ConstructFunctionGraph(const OpDef& sig, AttrSlice attrs,
+                                       GraphDef* g,
+                                       std::vector<string>* send_keys,
+                                       std::vector<string>* recv_keys);
   friend class ClusterFunctionLibraryRuntimeTest;
 
   mutable mutex mu_;
diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
index 98512bc..6dd8b9e 100644
--- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
+++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc
@@ -47,31 +47,30 @@
         new ClusterFunctionLibraryRuntime(worker_session_.get()));
   }
 
-  Status ConstructFunctionGraphHelper(
-      const OpDef& sig, test::function::Attrs attrs,
-      const FunctionLibraryRuntime::InstantiateOptions& options, GraphDef* g,
-      std::vector<string>* send_keys, std::vector<string>* recv_keys) {
+  Status ConstructFunctionGraphHelper(const OpDef& sig,
+                                      test::function::Attrs attrs, GraphDef* g,
+                                      std::vector<string>* send_keys,
+                                      std::vector<string>* recv_keys) {
     return ClusterFunctionLibraryRuntime::ConstructFunctionGraph(
-        sig, attrs, options, g, send_keys, recv_keys);
+        sig, attrs, g, send_keys, recv_keys);
   }
 
   Status Instantiate(const string& function_name,
                      const FunctionLibraryDefinition& lib_def,
                      test::function::Attrs attrs,
-                     const FunctionLibraryRuntime::InstantiateOptions& options,
                      FunctionLibraryRuntime::LocalHandle* local_handle) {
-    return cluster_flr_->Instantiate(function_name, lib_def, attrs, options,
+    return cluster_flr_->Instantiate(function_name, lib_def, attrs,
                                      local_handle);
   }
 
-  Status InstantiateAndRun(
-      const string& function_name, const FunctionLibraryDefinition& lib_def,
-      test::function::Attrs attrs,
-      const FunctionLibraryRuntime::InstantiateOptions& options,
-      const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
+  Status InstantiateAndRun(const string& function_name,
+                           const FunctionLibraryDefinition& lib_def,
+                           test::function::Attrs attrs,
+                           const std::vector<Tensor>& args,
+                           std::vector<Tensor*> rets) {
     FunctionLibraryRuntime::LocalHandle handle;
-    TF_RETURN_IF_ERROR(cluster_flr_->Instantiate(function_name, lib_def, attrs,
-                                                 options, &handle));
+    TF_RETURN_IF_ERROR(
+        cluster_flr_->Instantiate(function_name, lib_def, attrs, &handle));
 
     Notification done;
     FunctionLibraryRuntime::Options opts;
@@ -104,9 +103,9 @@
   GraphDef actual;
   std::vector<string> send_keys, recv_keys;
   TF_CHECK_OK(ConstructFunctionGraphHelper(
-      test::function::Swap().signature(), {{"T", DT_FLOAT}},
-      {"/job:a/replica:0/task:0/device:CPU:0"}, &actual, &send_keys,
-      &recv_keys));
+      test::function::Swap().signature(),
+      {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, &actual,
+      &send_keys, &recv_keys));
   GraphDef expected;
   protobuf::TextFormat::ParseFromString(R"(
 node {
@@ -206,7 +205,7 @@
   attr {
     key: "_target"
     value {
-      s: "/job:a/replica:0/task:0/device:CPU:0"
+      s: "/job:a/replica:0/task:0/cpu:0"
     }
   }
 }
@@ -310,9 +309,9 @@
 
   Tensor y;
   auto x = test::AsTensor<int32>({1, 2, 3, 4});
-  TF_EXPECT_OK(InstantiateAndRun("XTimesTwoInt32", lib_def, {},
-                                 {"/job:localhost/replica:0/task:1/cpu:0"}, {x},
-                                 {&y}));
+  TF_EXPECT_OK(InstantiateAndRun(
+      "XTimesTwoInt32", lib_def,
+      {{"_target", "/job:localhost/replica:0/task:1/cpu:0"}}, {x}, {&y}));
   test::ExpectTensorEqual<int32>(y, test::AsTensor<int32>({2, 4, 6, 8}));
 }
 
@@ -325,9 +324,10 @@
   Tensor y1, y2;
   auto x1 = test::AsTensor<float>({1, 2, 3, 4});
   auto x2 = test::AsTensor<float>({4, 3, 2, 1});
-  TF_EXPECT_OK(InstantiateAndRun("Swap", lib_def, {{"T", DT_FLOAT}},
-                                 {"/job:localhost/replica:0/task:1/cpu:0"},
-                                 {x1, x2}, {&y1, &y2}));
+  TF_EXPECT_OK(InstantiateAndRun(
+      "Swap", lib_def,
+      {{"T", DT_FLOAT}, {"_target", "/job:localhost/replica:0/task:1/cpu:0"}},
+      {x1, x2}, {&y1, &y2}));
   test::ExpectTensorEqual<float>(y1, test::AsTensor<float>({4, 3, 2, 1}));
   test::ExpectTensorEqual<float>(y2, test::AsTensor<float>({1, 2, 3, 4}));
 }
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 7830154..d757e96 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -795,17 +795,12 @@
   return h;
 }
 
-string Canonicalize(const string& funcname, AttrSlice attrs,
-                    const FunctionLibraryRuntime::InstantiateOptions& options) {
+string Canonicalize(const string& funcname, AttrSlice attrs) {
   std::vector<string> entries;
-  entries.reserve(options.target.empty() ? attrs.size() : (attrs.size() + 1));
+  entries.reserve(attrs.size());
   for (auto p : attrs) {
     entries.push_back(strings::StrCat(p.first, "=", Print(p.second)));
   }
-  if (!options.target.empty()) {
-    entries.push_back(
-        strings::StrCat("_target", "=", str_util::CEscape(options.target)));
-  }
   std::sort(entries.begin(), entries.end());
   return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]");
 }
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index e5d0e49..1a579ab 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -234,6 +234,15 @@
 // same.
 uint64 FunctionDefHash(const FunctionDef& fdef);
 
+// Returns a canonicalized string for the instantiation of the
+// function of the given "name" and attributes "attrs".
+//
+// The returned string is guaranteed to be stable within one address
+// space. But it may change as the implementation
+// evolves. Therefore, it should not be persisted or compared across
+// address spaces.
+string Canonicalize(const string& funcname, AttrSlice attrs);
+
 class CallFrameInterface {
  public:
   virtual ~CallFrameInterface() {}
@@ -409,23 +418,9 @@
   //
   // Returns OK and fills in "handle" if the instantiation succeeds.
   // Otherwise returns an error and "handle" is undefined.
-  struct InstantiateOptions {
-    // The canonical device name of the device on which the function
-    // should be instantiated. If empty, the function will be
-    // instantiated on the local device.
-    string target;
-
-    // TODO(b/70352992): Add an API for allowing a different
-    // FunctionLibraryDefinition to be overlaid on this runtime's library.
-  };
   typedef uint64 Handle;
   virtual Status Instantiate(const string& function_name, AttrSlice attrs,
-                             const InstantiateOptions& options,
                              Handle* handle) = 0;
-  Status Instantiate(const string& function_name, AttrSlice attrs,
-                     Handle* handle) {
-    return Instantiate(function_name, attrs, {}, handle);
-  }
 
   // Releases state associated with the handle.
   virtual Status ReleaseHandle(Handle handle) = 0;
@@ -507,19 +502,6 @@
   typedef uint64 LocalHandle;
 };
 
-// Returns a canonicalized string for the instantiation of the
-// function of the given "name", attributes "attrs", and "options".
-//
-// The returned string is guaranteed to be stable within one address
-// space. But it may be change as the implementation
-// evolves. Therefore, it should not be persisted or compared across
-// address spaces.
-string Canonicalize(const string& funcname, AttrSlice attrs,
-                    const FunctionLibraryRuntime::InstantiateOptions& options);
-inline string Canonicalize(const string& funcname, AttrSlice attrs) {
-  return Canonicalize(funcname, attrs, {});
-}
-
 const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
 const FunctionLibraryRuntime::LocalHandle kInvalidLocalHandle = -1;
 typedef std::function<Status(FunctionLibraryRuntime*, const NodeDef&,
@@ -532,11 +514,10 @@
   virtual ~DistributedFunctionLibraryRuntime() {}
 
   // The _target attr in attrs determines where the function is instantiated.
-  virtual Status Instantiate(
-      const string& function_name, const FunctionLibraryDefinition& lib_def,
-      AttrSlice attrs,
-      const FunctionLibraryRuntime::InstantiateOptions& options,
-      FunctionLibraryRuntime::LocalHandle* handle) = 0;
+  virtual Status Instantiate(const string& function_name,
+                             const FunctionLibraryDefinition& lib_def,
+                             AttrSlice attrs,
+                             FunctionLibraryRuntime::LocalHandle* handle) = 0;
 
   // opts.runner isn't used for execution.
   virtual void Run(const FunctionLibraryRuntime::Options& opts,
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 4771840..743c76d 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -347,6 +347,21 @@
 
 }  // namespace
 
+// Sets *input_type to the type of input `input_port`; fails if there is none.
+Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
+                        int input_port, DataType* input_type) {
+  DataTypeVector input_types;
+  for (const auto& arg : op_def.input_arg()) {
+    TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
+    if (input_types.size() > static_cast<size_t>(input_port)) {
+      *input_type = input_types[input_port];
+      return Status::OK();
+    }
+  }
+  return errors::InvalidArgument("Input ", input_port, " not found for node ",
+                                 node_def.name());
+}
+
 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
                          DataTypeVector* inputs, DataTypeVector* outputs) {
   for (const auto& arg : op_def.input_arg()) {
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 4e98522..812cf1b 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -239,6 +239,11 @@
 // REQUIRES: Must not use the returned value beyond the lifetime of node_def.
 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name);
 
+// Computes the input type for a specific node input.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
+                        int input_port, DataType* input_type);
+
 // Computes the input and output types for a specific node.
 // REQUIRES: ValidateOpDef(op_def).ok()
 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
diff --git a/tensorflow/core/grappler/clusters/single_machine_test.cc b/tensorflow/core/grappler/clusters/single_machine_test.cc
index df936ef..2684100 100644
--- a/tensorflow/core/grappler/clusters/single_machine_test.cc
+++ b/tensorflow/core/grappler/clusters/single_machine_test.cc
@@ -58,6 +58,10 @@
   std::unique_ptr<SingleMachine> cluster_;
 };
 
+TEST_F(SingleMachineTest, ClusterType) {
+  EXPECT_EQ("single_machine", cluster_->type());
+}
+
 TEST_F(SingleMachineTest, CostModel) {
   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                           cluster_->GetDeviceNames());
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster_test.cc b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc
index fd925a6..357b306 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster_test.cc
+++ b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc
@@ -56,6 +56,10 @@
   std::unique_ptr<VirtualCluster> cluster_;
 };
 
+TEST_F(VirtualClusterTest, ClusterType) {
+  EXPECT_EQ("virtual", cluster_->type());
+}
+
 TEST_F(VirtualClusterTest, CostModel) {
   TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
                                           cluster_->GetDeviceNames());
diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc
index 6022c47..3168758 100644
--- a/tensorflow/core/grappler/costs/graph_memory.cc
+++ b/tensorflow/core/grappler/costs/graph_memory.cc
@@ -32,7 +32,17 @@
     const std::unordered_map<string, DeviceProperties>& devices) {
   VirtualCluster cluster(devices);
   TF_RETURN_IF_ERROR(cluster.Provision());
-  return InferDynamically(&cluster);
+  TF_RETURN_IF_ERROR(cluster.Initialize(item_));
+  RunMetadata metadata;
+  Status s = cluster.Run(item_.graph, item_.feed, item_.fetch, &metadata);
+  // The virtual cluster returns the RESOURCE_EXHAUSTED error when it detects
+  // that the model would run out of memory. We still get the metadata we need
+  // out of the simulation, so we just ignore this error.
+  if (!s.ok() && s.code() != error::RESOURCE_EXHAUSTED) {
+    return s;
+  }
+  InferFromTrace(metadata.step_stats());
+  return Status::OK();
 }
 
 Status GraphMemory::InferDynamically(Cluster* cluster) {
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index a24310a..63dfade 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -29,8 +29,8 @@
 class GraphView {
  public:
   struct Port {
-    NodeDef* node;
-    int port_id;
+    NodeDef* node = nullptr;
+    int port_id = -1;
 
     bool operator==(const Port& other) const {
       return node == other.node && port_id == other.port_id;
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index e557adc..791ad34 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -279,6 +279,7 @@
         ":graph_optimizer",
         ":graph_rewriter",
         ":static_schedule",
+        "//tensorflow/core:framework",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:graph_view",
         "//tensorflow/core/grappler:grappler_item",
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 37610d2..9b7f572 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -1699,10 +1699,6 @@
     int input_port;
     auto input = node_map_->GetNode(node_->input(0));
     ParseNodeName(node_->input(0), &input_port);
-    if (IsTransposeNCHWToNHWC(input->name())) {
-      input = node_map_->GetNode(input->input(0));
-      ParseNodeName(input->input(0), &input_port);
-    }
     if (input->attr().find("_output_shapes") != input->attr().end()) {
       auto shape = input->attr().at("_output_shapes").list().shape(input_port);
       if (shape.dim_size() != 4) {
@@ -1745,13 +1741,28 @@
     int port;
     ParseNodeName(node_->input(0), &port);
     return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
-           IsPortDimsFour(*input0, port) && IsAlongAllFourDims() && IsOnGPU();
+           IsPortDimsFour(*input0, port) && IsReduceAxisSupported() &&
+           IsOnGPU();
+  }
+
+  Status CustomizedProcessing() override {
+    if (IsAlongNHW() || IsAlongHW() || IsAlongC()) {
+      DataType dtype = node_->attr().at("Tidx").type();
+      TF_RETURN_IF_ERROR(
+          UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype));
+    }
+    return Status::OK();
   }
 
   Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
 
  private:
-  bool IsAlongAllFourDims() const {
+  bool IsReduceAxisSupported() const {
+    return IsAlongAllFourDims() || IsAlongHWC() ||
+           ((IsAlongNHW() || IsAlongHW() || IsAlongC()) && !KeepDims());
+  }
+
+  bool IsAlongAxis(const std::vector<int>& axis) const {
     auto axis_node = node_map_->GetNode(node_->input(1));
     if (!IsConstant(*axis_node)) {
       return false;
@@ -1762,15 +1773,28 @@
       if (!success) {
         LOG(ERROR) << "Failed to parse TensorProto.";
       }
-      if (tensor.dims() == 1 && tensor.dim_size(0) == 4) {
-        if (tensor.flat<int>()(0) == 0 && tensor.flat<int>()(1) == 1 &&
-            tensor.flat<int>()(2) == 2 && tensor.flat<int>()(3) == 3) {
-          return true;
+      if (tensor.dims() == 1 && tensor.dim_size(0) == axis.size()) {
+        bool along_axis = true;
+        for (int i = 0; i < axis.size(); i++) {
+          along_axis = along_axis && (tensor.flat<int>()(i) == axis[i]);
         }
+        if (along_axis) return true;
       }
     }
     return false;
   }
+
+  bool IsAlongAllFourDims() const { return IsAlongAxis({0, 1, 2, 3}); }
+
+  bool IsAlongHWC() const { return IsAlongAxis({1, 2, 3}); }
+
+  bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
+
+  bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
+
+  bool IsAlongC() const { return IsAlongAxis({3}); }
+
+  bool KeepDims() const { return node_->attr().at("keep_dims").b(); }
 };
 
 class SwitchProcessor : public AgnosticNodeProcessor {
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 1420fdb..bb4839d 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -23,6 +23,7 @@
 
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/grappler/costs/graph_memory.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -480,8 +481,19 @@
   }
 }
 
-std::pair<NodeDef*, NodeDef*> BuildSwapPair(NodeDef* node, int input_to_swap,
-                                            GraphDef* graph) {
+Status BuildSwapPair(NodeDef* node, int input_to_swap, GraphDef* graph,
+                     std::pair<NodeDef*, NodeDef*>* swap_pair) {
+  const OpDef* op_def;
+  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
+  DataType input_type;
+  TF_RETURN_IF_ERROR(
+      InputTypeForNode(*node, *op_def, input_to_swap, &input_type));
+  if (IsRefType(input_type)) {
+    return errors::InvalidArgument("Can't swap input ", input_to_swap,
+                                   " of node ", node->name(),
+                                   " since it expects a reference");
+  }
+
   string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);
 
   // Force the tensor to be copied to cpu.
@@ -501,10 +513,11 @@
   (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
   (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
 
-  const DataType input_type = node->attr().at("T").type();
   (*swap_in_node->mutable_attr())["T"].set_type(input_type);
   (*swap_out_node->mutable_attr())["T"].set_type(input_type);
-  return std::make_pair(swap_out_node, swap_in_node);
+  *swap_pair = std::make_pair(swap_out_node, swap_in_node);
+
+  return Status::OK();
 }
 
 static int64 EstimateSize(const OpInfo::TensorProperties& t) {
@@ -568,9 +581,12 @@
   max_trigger_time -= swap_info.time_to_swap;
 
   std::map<Costs::NanoSeconds, const NodeDef*> candidates;
+  std::set<string> already_processed;
+
   while (!possible_inputs.empty()) {
     const string input_node_name = *possible_inputs.begin();
     possible_inputs.erase(possible_inputs.begin());
+    already_processed.insert(input_node_name);
     auto it1 = name_map.find(input_node_name);
     if (it1 == name_map.end()) {
       return nullptr;
@@ -579,7 +595,7 @@
     // Don't jump over frames, since adding a control dependency from one frame
     // to the next isn't supported. Don't go through branches, since we don't
     // know whether they'll be executed or not.
-    if (IsNextIteration(*input_node) || IsSwitch(*input_node) ||
+    if (ModifiesFrameInfo(*input_node) || IsSwitch(*input_node) ||
         IsMerge(*input_node)) {
       continue;
     }
@@ -591,7 +607,10 @@
       candidates[it2->second] = input_node;
     } else {
       for (const string& fanin : input_node->input()) {
-        possible_inputs.insert(NodeName(fanin));
+        string name = NodeName(fanin);
+        if (already_processed.find(name) == already_processed.end()) {
+          possible_inputs.insert(name);
+        }
       }
     }
   }
@@ -605,13 +624,31 @@
   return nullptr;
 }
 
+// Returns true iff `input` may be swapped: its op must be found in the global
+// registry and the type of the consumed input must not be a reference type.
+static bool IsSwappable(GraphView::InputPort input) {
+  const NodeDef& consumer = *input.node;
+  const OpDef* op_def;
+  if (!OpRegistry::Global()->LookUpOpDef(consumer.op(), &op_def).ok()) {
+    return false;
+  }
+
+  DataType input_type;
+  const Status type_status =
+      InputTypeForNode(consumer, *op_def, input.port_id, &input_type);
+  // An input whose type cannot be resolved is reported as not swappable.
+  return type_status.ok() && !IsRefType(input_type);
+}
+
 static void IdentifySwappingCandidates(Cluster* cluster,
                                        const GrapplerItem& item,
                                        GraphDef* optimized_graph) {
   GraphMemory memory(item);
   const std::unordered_map<string, DeviceProperties>& devices =
       cluster->GetDevices();
-  if (!memory.InferStatically(devices).ok()) {
+  Status s = memory.InferStatically(devices);
+  if (!s.ok()) {
+    VLOG(1) << "Failed to infer memory usage: " << s.error_message();
     return;
   }
 
@@ -622,24 +659,36 @@
       continue;
     }
     if (prop.memory_size() <= 0) {
+      VLOG(1) << "Peak memory usage unknown for device " << name;
       continue;
     }
     const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage(name);
+
     if (mem_usage.used_memory <= prop.memory_size()) {
       continue;
     }
     int64 required_savings = mem_usage.used_memory - prop.memory_size();
     // TODO(bsteiner): sort the tensors by how long they're live.
 
-    std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
-    if (!EstimateEarliestExecutionTimes(item, cluster, &execution_times).ok()) {
-      return;
+    std::unordered_map<string, Costs::NanoSeconds> execution_times;
+    {
+      std::unordered_map<const NodeDef*, Costs::NanoSeconds>
+          tmp_execution_times;
+      if (!EstimateEarliestExecutionTimes(item, cluster, &tmp_execution_times)
+               .ok()) {
+        return;
+      }
+      for (const auto& exec_time : tmp_execution_times) {
+        execution_times.emplace(exec_time.first->name(), exec_time.second);
+      }
     }
+
     GraphView graph(optimized_graph);
     for (const auto& live_tensor : mem_usage.live_tensors) {
       if (live_tensor.deallocation_time - live_tensor.allocation_time <=
           Costs::Duration(1e6)) {
         // Not enough time to swap.
+        VLOG(1) << "Not enough time to swap: skipping " << live_tensor.node;
         continue;
       }
       if (live_tensor.memory_used <= 1024) {
@@ -651,7 +700,10 @@
       GraphView::OutputPort port =
           graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
       for (GraphView::InputPort input : graph.GetFanout(port)) {
-        auto it = execution_times.find(input.node);
+        if (!IsSwappable(input)) {
+          continue;
+        }
+        auto it = execution_times.find(input.node->name());
         if (it != execution_times.end()) {
           if (it->second > execution_time) {
             fanout_to_swap = input;
@@ -661,15 +713,23 @@
       }
       // Annotate the fanout to request the tensor to be swapped if it's not
       // already been done.
-      AttrValue& val = (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
       bool found = false;
-      for (int port_id : val.list().i()) {
-        if (port_id == fanout_to_swap.port_id) {
-          found = true;
-          break;
+      if (!fanout_to_swap.node) {
+        continue;
+      }
+      auto it = fanout_to_swap.node->attr().find("_swap_to_host");
+      if (it != fanout_to_swap.node->attr().end()) {
+        const AttrValue& val = it->second;
+        for (int port_id : val.list().i()) {
+          if (port_id == fanout_to_swap.port_id) {
+            found = true;
+            break;
+          }
         }
       }
       if (!found) {
+        AttrValue& val =
+            (*fanout_to_swap.node->mutable_attr())["_swap_to_host"];
         val.mutable_list()->add_i(fanout_to_swap.port_id);
         required_savings -= live_tensor.memory_used;
         if (required_savings < 0) {
@@ -688,7 +748,8 @@
                              recomputation_targets_name_prefix_,
                              optimized_graph, item);
 
-  if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS) {
+  if (optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS &&
+      cluster != nullptr) {
     IdentifySwappingCandidates(cluster, item, optimized_graph);
   }
 
@@ -713,7 +774,6 @@
     return Status::OK();
   }
 
-  {
     // Estimate the size of the data to swap for each node.
     GraphProperties properties(item);
     TF_RETURN_IF_ERROR(properties.InferStatically(true));
@@ -730,7 +790,6 @@
       // Let's assume we're going to swap over PCIe running at 16 GBps.
       swap_info.time_to_swap = bytes_to_swap / 16;
     }
-  }
 
   std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
   TF_RETURN_IF_ERROR(
@@ -743,7 +802,7 @@
 
   for (auto& swap : nodes_to_swap) {
     NodeDef* node = swap.first;
-    SwapInfo& swap_info = swap.second;
+    const SwapInfo& swap_info = swap.second;
 
     // Make sure the tensor isn't swapped back in right away: look for node that
     // will execute just before we need to swap the data back, and add a control
@@ -755,8 +814,10 @@
     }
     // Swap all the tensors that are marked with the 'swap_to_host' attribute.
     for (int input_id : swap_info.inputs_to_swap) {
-      std::pair<NodeDef*, NodeDef*> swap_nodes =
-          BuildSwapPair(node, input_id, optimized_graph);
+      std::pair<NodeDef*, NodeDef*> swap_nodes;
+      if (!BuildSwapPair(node, input_id, optimized_graph, &swap_nodes).ok()) {
+        continue;
+      }
       *swap_nodes.first->add_input() = node->input(input_id);
       *node->mutable_input(input_id) = swap_nodes.second->name();
 
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
index 6fa4731..6448d1e 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc
@@ -201,8 +201,16 @@
     cpu_device.set_frequency(1000);
     cpu_device.set_num_cores(4);
     cpu_device.set_bandwidth(32);
+    DeviceProperties gpu_device;
+    gpu_device.set_type("GPU");
+    gpu_device.set_frequency(1000);
+    gpu_device.set_num_cores(24);
+    gpu_device.set_bandwidth(128);
+    gpu_device.set_memory_size(1024 * 1024);
+    gpu_device.mutable_environment()->insert({"architecture", "6"});
     std::unordered_map<string, DeviceProperties> devices;
     devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device;
+    devices["/job:localhost/replica:0/task:0/gpu:0"] = gpu_device;
     return std::unique_ptr<VirtualCluster>(new VirtualCluster(devices));
   }
 };
@@ -252,6 +260,74 @@
   EXPECT_EQ("^c", swap_in.input(1));
 }
 
+TEST_F(MemoryOptimizerTest, SwappingHeuristics) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
+                           {128, 128, 8}, DT_FLOAT);
+  Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
+  Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
+  Output d = ops::Identity(s.WithOpName("d").WithDevice("/gpu:0"), {a});
+  Output axis = ops::Const(s.WithOpName("axis"), 0);
+  Output e =
+      ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"e"};
+
+  std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
+
+  MemoryOptimizer optimizer(RewriterConfig::SWAPPING_HEURISTICS);
+  GraphDef output;
+  TF_EXPECT_OK(optimizer.Optimize(cluster.get(), item, &output));
+
+  for (const auto& node : output.node()) {
+    if (node.name() == "e") {
+      // ASSERT (not EXPECT): Map::at() below CHECK-fails on a missing key.
+      ASSERT_TRUE(node.attr().count("_swap_to_host") > 0);
+      const AttrValue& val = node.attr().at("_swap_to_host");
+      EXPECT_TRUE(val.has_list());
+      std::set<int> inputs_to_swap;
+      for (int64 input_id : val.list().i()) {
+        inputs_to_swap.insert(input_id);
+      }
+      EXPECT_EQ(std::set<int>({0, 1, 2}), inputs_to_swap);
+    }
+  }
+}
+
+TEST_F(MemoryOptimizerTest, UnswappableInputs) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"),
+                           {128, 128, 8}, DT_FLOAT);
+  Output b = ops::Identity(s.WithOpName("b").WithDevice("/gpu:0"), {a});
+  Output c = ops::Identity(s.WithOpName("c").WithDevice("/gpu:0"), {a});
+  Output index = ops::Const(s.WithOpName("index"), {0});
+  Output indices = ops::Tile(s.WithOpName("indices"), index, {128});
+  Output d =
+      ops::ScatterAdd(s.WithOpName("d").WithDevice("/gpu:0"), a, indices, c);
+  Output axis = ops::Const(s.WithOpName("axis"), 0);
+  Output e =
+      ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"e"};
+
+  std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());
+
+  MemoryOptimizer optimizer(RewriterConfig::SWAPPING_HEURISTICS);
+  GraphDef output;
+  Status status = optimizer.Optimize(cluster.get(), item, &output);
+  TF_EXPECT_OK(status);
+
+  for (const auto& node : output.node()) {
+    if (node.name() == "d") {
+      EXPECT_EQ(0, node.attr().count("_swap_to_host"));
+    }
+  }
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 6899767..6f7a0fd 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -164,7 +164,7 @@
                 buffer_element.value.size()));
             for (size_t j = 0; j < buffer_element.value.size(); j++) {
               TF_RETURN_IF_ERROR(writer->WriteTensor(
-                  strings::StrCat("buffer[", i, "][", j, "]"),
+                  full_name(strings::StrCat("buffer[", i, "][", j, "]")),
                   buffer_element.value[j]));
             }
           }
@@ -201,7 +201,7 @@
             for (size_t j = 0; j < value_size; j++) {
               buffer_element.value.emplace_back();
               TF_RETURN_IF_ERROR(reader->ReadTensor(
-                  strings::StrCat("buffer[", i, "][", j, "]"),
+                  full_name(strings::StrCat("buffer[", i, "][", j, "]")),
                   &buffer_element.value.back()));
             }
           }
diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
index abe3126..71af30e 100644
--- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
+++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
@@ -14,6 +14,7 @@
 ==============================================================================*/
 #include "tensorflow/core/kernels/data/sql/sqlite_query_connection.h"
 
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 
 namespace tensorflow {
@@ -40,21 +41,15 @@
 }
 
 Status SqliteQueryConnection::Close() {
-  Status s;
-  s.Update(stmt_.Close());
-  s.Update(db_->Close());
-  return s;
+  stmt_ = SqliteStatement();
+  db_.reset();
+  return Status::OK();
 }
 
 Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors,
                                       bool* end_of_sequence) {
-  if (!stmt_) {
-    Status s = PrepareQuery();
-    if (!s.ok()) {
-      return s;
-    }
-  }
-  Status s = stmt_.Step(end_of_sequence);
+  if (!stmt_) TF_RETURN_IF_ERROR(PrepareQuery());
+  TF_RETURN_IF_ERROR(stmt_.Step(end_of_sequence));
   if (!*end_of_sequence) {
     for (int i = 0; i < column_count_; i++) {
       DataType dt = output_types_[i];
@@ -63,64 +58,48 @@
       out_tensors->emplace_back(std::move(tensor));
     }
   }
-  return s;
+  return Status::OK();
 }
 
 Status SqliteQueryConnection::PrepareQuery() {
-  stmt_ = db_->Prepare(query_);
-  Status s = stmt_.status();
-  if (s.ok()) {
-    int column_count = stmt_.ColumnCount();
-    if (column_count != output_types_.size()) {
-      return errors::InvalidArgument(tensorflow::strings::Printf(
-          "The number of columns in query (%d) must match the number of "
-          "elements in output_types (%zu).",
-          column_count, output_types_.size()));
-    }
-    column_count_ = column_count;
+  auto prep = db_->Prepare(query_);
+  TF_RETURN_IF_ERROR(prep.status());
+  int column_count = prep.ValueOrDie().ColumnCount();
+  if (column_count != output_types_.size()) {
+    return errors::InvalidArgument(tensorflow::strings::Printf(
+        "The number of columns in query (%d) must match the number of "
+        "elements in output_types (%zu).",
+        column_count, output_types_.size()));
   }
-  return s;
+  stmt_ = prep.ConsumeValueOrDie();
+  column_count_ = column_count;
+  return Status::OK();
 }
 
 void SqliteQueryConnection::FillTensorWithResultSetEntry(
     const DataType& data_type, int column_index, Tensor* tensor) {
+#define CASE(T, M)                                                 \
+  case DataTypeToEnum<T>::value:                                   \
+    tensor->scalar<T>()() = static_cast<T>(stmt_.M(column_index)); \
+    break;
+#define INT_CASE(T) CASE(T, ColumnInt)
+#define DOUBLE_CASE(T) CASE(T, ColumnDouble)
+#define STRING_CASE(T) CASE(T, ColumnString)
   switch (data_type) {
-    case DT_STRING:
-      tensor->scalar<string>()() = stmt_.ColumnString(column_index);
-      break;
-    case DT_INT8:
-      tensor->scalar<int8>()() =
-          static_cast<int8>(stmt_.ColumnInt(column_index));
-      break;
-    case DT_INT16:
-      tensor->scalar<int16>()() =
-          static_cast<int16>(stmt_.ColumnInt(column_index));
-      break;
-    case DT_INT32:
-      tensor->scalar<int32>()() =
-          static_cast<int32>(stmt_.ColumnInt(column_index));
-      break;
-    case DT_INT64:
-      tensor->scalar<int64>()() = stmt_.ColumnInt(column_index);
-      break;
-    case DT_UINT8:
-      tensor->scalar<uint8>()() =
-          static_cast<uint8>(stmt_.ColumnInt(column_index));
-      break;
-    case DT_UINT16:
-      tensor->scalar<uint16>()() =
-          static_cast<uint16>(stmt_.ColumnInt(column_index));
-      break;
+    TF_CALL_int8(INT_CASE)
+    TF_CALL_uint8(INT_CASE)
+    TF_CALL_int16(INT_CASE)
+    TF_CALL_uint16(INT_CASE)
+    TF_CALL_int32(INT_CASE)
+    TF_CALL_uint32(INT_CASE)
+    TF_CALL_int64(INT_CASE)
+    TF_CALL_uint64(INT_CASE)
+    TF_CALL_float(DOUBLE_CASE)
+    TF_CALL_double(DOUBLE_CASE)
+    TF_CALL_string(STRING_CASE)
     case DT_BOOL:
       tensor->scalar<bool>()() = stmt_.ColumnInt(column_index) != 0;
       break;
-    case DT_FLOAT:
-      tensor->scalar<float>()() =
-          static_cast<float>(stmt_.ColumnDouble(column_index));
-      break;
-    case DT_DOUBLE:
-      tensor->scalar<double>()() = stmt_.ColumnDouble(column_index);
-      break;
       // Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
     default: {
       LOG(FATAL)
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index facac10..f469f41 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -296,19 +296,21 @@
   void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
     const Tensor* target;
     OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
+    AttrValueMap attr_values = func_.attr();
+    AttrValue v;
     const string& target_device =
         DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()());
+    v.set_s(target_device);
+    AddAttr("_target", v, &attr_values);
 
     FunctionLibraryRuntime* lib = ctx->function_library();
     OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                       errors::Internal("No function library is provided."),
                       done);
-    AttrValueMap attr_values = func_.attr();
     FunctionLibraryRuntime::Handle handle;
-    OP_REQUIRES_OK_ASYNC(ctx,
-                         lib->Instantiate(func_.name(), AttrSlice(&attr_values),
-                                          {target_device}, &handle),
-                         done);
+    OP_REQUIRES_OK_ASYNC(
+        ctx, lib->Instantiate(func_.name(), AttrSlice(&attr_values), &handle),
+        done);
 
     OpInputList arguments;
     OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 7ccaef9..9b4dca8 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -575,27 +575,6 @@
   bool is_training_;
 };
 
-namespace {
-
-template <typename Device>
-void FillZeros(Tensor* t);
-
-#if GOOGLE_CUDA
-template <>
-void FillZeros<GPUDevice>(Tensor* t) {
-  cudaMemset(const_cast<char*>(t->tensor_data().data()), 0,
-             t->tensor_data().size());
-}
-#endif
-
-template <>
-void FillZeros<CPUDevice>(Tensor* t) {
-  memset(const_cast<char*>(t->tensor_data().data()), 0,
-         t->tensor_data().size());
-}
-
-}  // namespace
-
 template <typename Device, typename T, typename U>
 class FusedBatchNormGradOp : public OpKernel {
  public:
@@ -659,11 +638,12 @@
     Tensor* placeholder_1 = nullptr;
     OP_REQUIRES_OK(
         context, context->allocate_output(3, TensorShape({}), &placeholder_1));
-    FillZeros<Device>(placeholder_1);
+    functor::SetZeroFunctor<Device, float> f;
+    f(context->eigen_device<Device>(), placeholder_1->flat<U>());
     Tensor* placeholder_2 = nullptr;
     OP_REQUIRES_OK(
         context, context->allocate_output(4, TensorShape({}), &placeholder_2));
-    FillZeros<Device>(placeholder_2);
+    f(context->eigen_device<Device>(), placeholder_2->flat<U>());
 
     // If input is empty, set gradients w.r.t scale/offset to zero.
     if (x.shape().num_elements() == 0) {
diff --git a/tensorflow/core/kernels/quantization_utils_test.cc b/tensorflow/core/kernels/quantization_utils_test.cc
index a73581f..d148c9f 100644
--- a/tensorflow/core/kernels/quantization_utils_test.cc
+++ b/tensorflow/core/kernels/quantization_utils_test.cc
@@ -743,7 +743,8 @@
 void TestDivide64x2Pow(int64 val, int64 ref) {
   const int64x2_t val_64x2 = vmovq_n_s64(val);
   const int64x2_t ret = Divide64x2Pow<POW>(val_64x2);
-  int64 rets[2];
+  // TODO(b/70947959) Change back to int64 when possible
+  int64_t rets[2];
   vst1q_s64(rets, ret);
   EXPECT_EQ(rets[0], ref);
   EXPECT_EQ(rets[1], ref);
@@ -754,7 +755,8 @@
 void TestDivide64x2PowRound(int64 val, int64 ref) {
   const int64x2_t val_64x2 = vmovq_n_s64(val);
   const int64x2_t shifted = Divide64x2PowRound<POW>(val_64x2);
-  int64 rets[2];
+  // TODO(b/70947959) Change back to int64 when possible
+  int64_t rets[2];
   vst1q_s64(rets, shifted);
   EXPECT_EQ(rets[0], ref) << "in = " << val << ", " << POW
                           << ", act = " << rets[0] << ", ref = " << ref;
diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc
index f092afe..a86c046 100644
--- a/tensorflow/core/kernels/summary_kernels.cc
+++ b/tensorflow/core/kernels/summary_kernels.cc
@@ -67,9 +67,8 @@
     SummaryWriterInterface* s;
     auto db = Sqlite::Open(db_uri);
     OP_REQUIRES_OK(ctx, db.status());
-    db.ValueOrDie()->UseWriteAheadLogWithReducedDurabilityIfPossible();
     OP_REQUIRES_OK(
-        ctx, CreateSummaryDbWriter(std::move(db.ValueOrDie()), experiment_name,
+        ctx, CreateSummaryDbWriter(db.ConsumeValueOrDie(), experiment_name,
                                    run_name, user_name, ctx->env(), &s));
     OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s));
   }
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index cca6d0e..66aee2d 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -336,8 +336,7 @@
           tensor_array->HasIdenticalElementShapes(), false /* dynamic_size */,
           true /* multiple_writes_aggregate */, true /* is_grad */,
           marked_size /* marked_size */, true /* close_after_read */);
-      TF_RETURN_IF_ERROR((*ret)->CopyShapesFrom(tensor_array));
-      return Status::OK();
+      return (*ret)->CopyShapesFrom(tensor_array);
     };
 
     Status s = rm->LookupOrCreate<TensorArray>(
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index 7824702..e64b27b 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -83,65 +83,100 @@
       }
     }
 
-    auto Tin = input.shaped<T, 3>(new_sizes);
-
     Tensor* idx = nullptr;
     OP_REQUIRES_OK(context, context->allocate_output(
-                                1, TensorShape({Tin.dimension(1)}), &idx));
+                                1, TensorShape({new_sizes[1]}), &idx));
     auto idx_vec = idx->template vec<TIndex>();
 
-    auto hash_fn = [&Tin](const int64& key) -> unsigned long {
-      size_t h = 0;
-      for (int64 i = 0; i < Tin.dimension(0); i++) {
-        for (int64 j = 0; j < Tin.dimension(2); j++) {
-          h = Hash64Combine(h, hash<T>{}(Tin(i, key, j)));
+    int64 uniq_size;
+    if (new_sizes[0] == 1 && new_sizes[2] == 1) {
+      // Specialized and faster implementation when unique is run over single
+      // elements. Here we put T directly into the map rather than ints pointing
+      // to them as in the general case.
+      auto Tin = input.flat<T>();
+      const int64 N = static_cast<int64>(Tin.size());
+
+      std::unordered_map<T, TIndex> uniq;
+      uniq.reserve(2 * N);
+      for (int64 i = 0, j = 0; i < N; ++i) {
+        auto it = uniq.insert(std::make_pair(Tin(i), j));
+        idx_vec(i) = it.first->second;
+        if (it.second) {
+          ++j;
         }
       }
-      return h;
-    };
 
-    auto equal_to_fn = [&Tin](const int64& lhs, const int64& rhs) {
-      for (int64 i = 0; i < Tin.dimension(0); i++) {
-        for (int64 j = 0; j < Tin.dimension(2); j++) {
-          if (Tin(i, lhs, j) != Tin(i, rhs, j)) {
-            return false;
+      uniq_size = static_cast<int64>(uniq.size());
+      TensorShape output_shape(input.shape());
+      output_shape.set_dim(axis, uniq_size);
+      Tensor* output = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(0, output_shape, &output));
+      auto Tout = output->flat<T>();
+
+      for (auto it : uniq) {
+        Tout(it.second) = it.first;
+      }
+    } else {
+      // General implementation when unique is run over multiple elements.
+      auto Tin = input.shaped<T, 3>(new_sizes);
+
+      auto hash_fn = [&Tin](const int64& key) {
+        size_t h = 0;
+        for (int64 i = 0; i < Tin.dimension(0); i++) {
+          for (int64 j = 0; j < Tin.dimension(2); j++) {
+            h = Hash64Combine(h, hash<T>{}(Tin(i, key, j)));
           }
         }
+        return h;
+      };
+
+      auto equal_to_fn = [&Tin](const int64& lhs, const int64& rhs) {
+        for (int64 i = 0; i < Tin.dimension(0); i++) {
+          for (int64 j = 0; j < Tin.dimension(2); j++) {
+            if (Tin(i, lhs, j) != Tin(i, rhs, j)) {
+              return false;
+            }
+          }
+        }
+        return true;
+      };
+
+      std::unordered_map<int64, int64, decltype(hash_fn), decltype(equal_to_fn)>
+          uniq(0, hash_fn, equal_to_fn);
+
+      uniq.reserve(2 * Tin.dimension(1));
+
+      for (int64 i = 0, j = 0; i < Tin.dimension(1); ++i) {
+        auto it = uniq.insert(std::make_pair(i, j));
+        idx_vec(i) = it.first->second;
+        if (it.second) {
+          ++j;
+        }
       }
-      return true;
-    };
 
-    std::unordered_map<int64, int64, decltype(hash_fn), decltype(equal_to_fn)>
-        uniq(0, hash_fn, equal_to_fn);
+      uniq_size = static_cast<int64>(uniq.size());
+      new_sizes[1] = uniq_size;
+      TensorShape output_shape(input.shape());
+      output_shape.set_dim(axis, uniq_size);
+      Tensor* output = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(0, output_shape, &output));
+      auto Tout = output->shaped<T, 3>(new_sizes);
 
-    uniq.reserve(2 * Tin.dimension(1));
-
-    for (int64 i = 0, j = 0; i < Tin.dimension(1); ++i) {
-      auto it = uniq.insert(std::make_pair(i, j));
-      idx_vec(i) = it.first->second;
-      if (it.second) {
-        ++j;
+      for (auto it : uniq) {
+        Tout.chip(it.second, 1) = Tin.chip(it.first, 1);
       }
     }
 
-    int64 uniq_size = static_cast<int64>(uniq.size());
-    new_sizes[1] = uniq_size;
-    TensorShape output_shape(input.shape());
-    output_shape.set_dim(axis, uniq_size);
-    Tensor* output = nullptr;
-    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
-    auto Tout = output->shaped<T, 3>(new_sizes);
-
-    for (auto it : uniq) {
-      Tout.chip(it.second, 1) = Tin.chip(it.first, 1);
-    }
-
     if (num_outputs() > 2) {
+      Tensor* output = nullptr;
       OP_REQUIRES_OK(context, context->allocate_output(
                                   2, TensorShape({uniq_size}), &output));
       auto count_output_vec = output->template vec<TIndex>();
       count_output_vec.setZero();
-      for (int64 i = 0; i < Tin.dimension(1); ++i) {
+      const int N = idx_vec.size();
+      for (int64 i = 0; i < N; ++i) {
         count_output_vec(idx_vec(i))++;
       }
     }
diff --git a/tensorflow/core/lib/db/sqlite.cc b/tensorflow/core/lib/db/sqlite.cc
index b0a9e2f..76bf778 100644
--- a/tensorflow/core/lib/db/sqlite.cc
+++ b/tensorflow/core/lib/db/sqlite.cc
@@ -14,224 +14,257 @@
 ==============================================================================*/
 #include "tensorflow/core/lib/db/sqlite.h"
 
-#include "tensorflow/core/lib/io/record_reader.h"
-#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
 
 extern "C" int sqlite3_snapfn_init(sqlite3*, const char**, const void*);
 
 namespace tensorflow {
 namespace {
 
-void ExecuteOrLog(Sqlite* db, const char* sql) {
-  Status s = db->Prepare(sql).StepAndReset();
-  if (!s.ok()) {
-    LOG(WARNING) << s.ToString();
+error::Code GetTfErrorCode(int code) {
+  // See: https://sqlite.org/rescode.html
+  switch (code & 0xff) {
+    case SQLITE_OK:    // Successful result
+    case SQLITE_ROW:   // Step has another row ready
+    case SQLITE_DONE:  // Step has finished executing
+      return error::OK;
+    case SQLITE_ABORT:  // Callback routine requested an abort
+      return error::ABORTED;
+    case SQLITE_READONLY:  // Attempt to write a readonly database
+    case SQLITE_MISMATCH:  // Data type mismatch
+      return error::FAILED_PRECONDITION;
+    case SQLITE_MISUSE:    // Library used incorrectly
+    case SQLITE_INTERNAL:  // Internal logic error in SQLite
+      return error::INTERNAL;
+    case SQLITE_RANGE:  // 2nd parameter to sqlite3_bind out of range
+      return error::OUT_OF_RANGE;
+    case SQLITE_CANTOPEN:    // Unable to open the database file
+    case SQLITE_CONSTRAINT:  // Abort due to constraint violation
+    case SQLITE_NOTFOUND:    // Unknown opcode or statement parameter name
+    case SQLITE_NOTADB:      // File opened that is not a database file
+      return error::INVALID_ARGUMENT;
+    case SQLITE_CORRUPT:  // The database disk image is malformed
+      return error::DATA_LOSS;
+    case SQLITE_AUTH:  // Authorization denied
+    case SQLITE_PERM:  // Access permission denied
+      return error::PERMISSION_DENIED;
+    case SQLITE_FULL:    // Insertion failed because database is full
+    case SQLITE_TOOBIG:  // String or BLOB exceeds size limit
+    case SQLITE_NOLFS:   // Uses OS features not supported on host
+      return error::RESOURCE_EXHAUSTED;
+    case SQLITE_BUSY:      // The database file is locked
+    case SQLITE_LOCKED:    // A table in the database is locked
+    case SQLITE_PROTOCOL:  // Database lock protocol error
+    case SQLITE_NOMEM:     // Out of heap or perhaps lookaside memory
+      return error::UNAVAILABLE;
+    case SQLITE_INTERRUPT:  // Operation terminated by sqlite3_interrupt
+      return error::CANCELLED;
+    case SQLITE_ERROR:   // SQL error or missing database
+    case SQLITE_IOERR:   // Some kind of disk I/O error occurred
+    case SQLITE_SCHEMA:  // The database schema changed
+    default:
+      return error::UNKNOWN;
   }
 }
 
-string ExecuteOrEmpty(Sqlite* db, const char* sql) {
-  auto stmt = db->Prepare(sql);
-  bool is_done = false;
-  if (stmt.Step(&is_done).ok() && !is_done) {
-    return stmt.ColumnString(0);
+template <typename... Args>
+Status PrintfStatus(int rc, const char* fmt, Args&&... args) {
+  return {GetTfErrorCode(rc),
+          strings::Printf(fmt, std::forward<Args>(args)...)};
+}
+
+Status AsStatus(Sqlite* db, int rc) EXCLUSIVE_LOCKS_REQUIRED(*db) {
+  if (TF_PREDICT_TRUE(rc == SQLITE_OK)) return Status::OK();
+  return {GetTfErrorCode(rc), db->errmsg()};
+}
+
+sqlite3_stmt* PrepareRawOrDie(sqlite3* db, const char* sql) {
+  sqlite3_stmt* stmt = nullptr;
+  int rc = sqlite3_prepare_v2(db, sql, -1, &stmt, nullptr);
+  CHECK_EQ(SQLITE_OK, rc) << sql;
+  return stmt;
+}
+
+Status SetEnvPragmaActual(Sqlite* db, const char* pragma, const char* var) {
+  const char* value = std::getenv(var);
+  if (value == nullptr || *value == '\0') return Status::OK();
+  for (const char* p = value; *p != '\0'; ++p) {
+    if (!(('0' <= *p && *p <= '9') || *p == '-' ||
+          ('A' <= *p && *p <= 'Z') ||
+          ('a' <= *p && *p <= 'z'))) {
+      return errors::InvalidArgument("Illegal character");
+    }
   }
-  return "";
+  // We can't use Bind*() for pragmas.
+  auto stmt = db->Prepare(strings::StrCat("PRAGMA ", pragma, "=", value));
+  TF_RETURN_IF_ERROR(stmt.status());
+  bool unused_done;
+  return stmt.ValueOrDie().Step(&unused_done);
+}
+
+Status EnvPragma(Sqlite* db, const char* pragma, const char* var) {
+  TF_RETURN_WITH_CONTEXT_IF_ERROR(SetEnvPragmaActual(db, pragma, var),
+                                  "getenv(", var, ")");
+  return Status::OK();
 }
 
 }  // namespace
 
 /* static */
-xla::StatusOr<std::shared_ptr<Sqlite>> Sqlite::Open(const string& uri) {
+xla::StatusOr<std::shared_ptr<Sqlite>> Sqlite::Open(string path, int flags) {
+  flags |= SQLITE_OPEN_PRIVATECACHE;
   sqlite3* sqlite = nullptr;
-  TF_RETURN_IF_ERROR(MakeStatus(sqlite3_open(uri.c_str(), &sqlite)));
-  CHECK_EQ(SQLITE_OK, sqlite3_snapfn_init(sqlite, nullptr, nullptr));
-  Sqlite* db = new Sqlite(sqlite, uri);
-  // This is the SQLite default since 2016. However it's good to set
-  // this anyway, since we might get linked against an older version of
-  // the library, and it's pretty much impossible to change later.
-  ExecuteOrLog(db, "PRAGMA page_size=4096");
-  return std::shared_ptr<Sqlite>(db);
-}
-
-/* static */ Status Sqlite::MakeStatus(int resultCode) {
-  // See: https://sqlite.org/rescode.html
-  switch (resultCode & 0xff) {
-    case SQLITE_OK:
-    case SQLITE_ROW:   // sqlite3_step() has another row ready
-    case SQLITE_DONE:  // sqlite3_step() has finished executing
-      return Status::OK();
-    case SQLITE_ABORT:  // Callback routine requested an abort
-      return errors::Aborted(sqlite3_errstr(resultCode));
-    case SQLITE_READONLY:  // Attempt to write a readonly database
-    case SQLITE_MISMATCH:  // Data type mismatch
-      return errors::FailedPrecondition(sqlite3_errstr(resultCode));
-    case SQLITE_MISUSE:    // Library used incorrectly
-    case SQLITE_INTERNAL:  // Internal logic error in SQLite
-      return errors::Internal(sqlite3_errstr(resultCode));
-    case SQLITE_RANGE:  // 2nd parameter to sqlite3_bind out of range
-      return errors::OutOfRange(sqlite3_errstr(resultCode));
-    case SQLITE_CANTOPEN:    // Unable to open the database file
-    case SQLITE_CONSTRAINT:  // Abort due to constraint violation
-    case SQLITE_NOTFOUND:    // Unknown opcode or statement parameter name
-    case SQLITE_NOTADB:      // File opened that is not a database file
-      return errors::InvalidArgument(sqlite3_errstr(resultCode));
-    case SQLITE_CORRUPT:  // The database disk image is malformed
-      return errors::DataLoss(sqlite3_errstr(resultCode));
-    case SQLITE_AUTH:  // Authorization denied
-    case SQLITE_PERM:  // Access permission denied
-      return errors::PermissionDenied(sqlite3_errstr(resultCode));
-    case SQLITE_FULL:    // Insertion failed because database is full
-    case SQLITE_TOOBIG:  // String or BLOB exceeds size limit
-    case SQLITE_NOLFS:   // Uses OS features not supported on host
-      return errors::ResourceExhausted(sqlite3_errstr(resultCode));
-    case SQLITE_BUSY:      // The database file is locked
-    case SQLITE_LOCKED:    // A table in the database is locked
-    case SQLITE_PROTOCOL:  // Database lock protocol error
-    case SQLITE_NOMEM:     // A malloc() failed
-      return errors::Unavailable(sqlite3_errstr(resultCode));
-    case SQLITE_INTERRUPT:  // Operation terminated by sqlite3_interrupt
-      return errors::Cancelled(sqlite3_errstr(resultCode));
-    case SQLITE_ERROR:   // SQL error or missing database
-    case SQLITE_IOERR:   // Some kind of disk I/O error occurred
-    case SQLITE_SCHEMA:  // The database schema changed
-    default:
-      return errors::Unknown(sqlite3_errstr(resultCode));
+  int rc = sqlite3_open_v2(path.c_str(), &sqlite, flags, nullptr);
+  if (rc != SQLITE_OK) {
+    return PrintfStatus(rc, "Sqlite::Open(%s) failed: %s", path.c_str(),
+                        sqlite3_errstr(rc));
   }
+  CHECK_EQ(SQLITE_OK, sqlite3_extended_result_codes(sqlite, 1));
+  CHECK_EQ(SQLITE_OK, sqlite3_snapfn_init(sqlite, nullptr, nullptr));
+  // Prepare these tiny privileged statements for SqliteTransaction
+  // so it can do less work, particularly in its constructor, per
+  // Google C++ Style.
+  sqlite3_stmt* begin = PrepareRawOrDie(sqlite, "BEGIN");
+  sqlite3_stmt* commit = PrepareRawOrDie(sqlite, "COMMIT");
+  sqlite3_stmt* rollback = PrepareRawOrDie(sqlite, "ROLLBACK");
+  auto r = std::shared_ptr<Sqlite>(
+      new Sqlite(sqlite, std::move(path), begin, commit, rollback));
+  r->self_ = std::weak_ptr<Sqlite>(r);
+  Sqlite* db = r.get();
+  // TensorFlow is designed to work well in all SQLite modes. However
+  // users might find tuning some of these pragmas rewarding, depending on
+  // various considerations.
+  TF_RETURN_IF_ERROR(EnvPragma(db, "secure_delete", "TF_SQLITE_SECURE_DELETE"));
+  TF_RETURN_IF_ERROR(EnvPragma(db, "page_size", "TF_SQLITE_PAGE_SIZE"));
+  TF_RETURN_IF_ERROR(EnvPragma(db, "journal_mode", "TF_SQLITE_JOURNAL_MODE"));
+  TF_RETURN_IF_ERROR(EnvPragma(db, "synchronous", "TF_SQLITE_SYNCHRONOUS"));
+  TF_RETURN_IF_ERROR(EnvPragma(db, "mmap_size", "TF_SQLITE_MMAP_SIZE"));
+  TF_RETURN_IF_ERROR(EnvPragma(db, "locking_mode", "TF_SQLITE_LOCKING_MODE"));
+  TF_RETURN_IF_ERROR(EnvPragma(db, "cache_size", "TF_SQLITE_CACHE_SIZE"));
+  TF_RETURN_IF_ERROR(EnvPragma(db, "auto_vacuum", "TF_SQLITE_AUTO_VACUUM"));
+  return r;
 }
 
-Sqlite::Sqlite(sqlite3* db, const string& uri) : db_(db), uri_(uri) {}
-
 Sqlite::~Sqlite() {
-  // close_v2 doesn't care if a stmt hasn't been GC'd yet
-  int rc = sqlite3_close_v2(db_);
-  if (rc != SQLITE_OK) {
-    LOG(ERROR) << "destruct sqlite3: " << MakeStatus(rc);
-  }
+  sqlite3_finalize(rollback_);
+  sqlite3_finalize(commit_);
+  sqlite3_finalize(begin_);
+  CHECK_EQ(SQLITE_OK, sqlite3_close(db_));
 }
 
-Status Sqlite::Close() {
-  if (db_ == nullptr) {
-    return Status::OK();
-  }
-  // If Close is explicitly called, ordering must be correct.
-  Status s = MakeStatus(sqlite3_close(db_));
-  if (s.ok()) {
-    db_ = nullptr;
-  }
-  return s;
-}
-
-void Sqlite::UseWriteAheadLogWithReducedDurabilityIfPossible() {
-  // TensorFlow summaries are intensively write-heavy, cf. most apps.
-  // This pragma loves writes and means that TensorBoard can read the
-  // database even as the training job inserts stuff. In other words,
-  // this makes SQLite almost as powerful as MySQL or PostgreSQL.
-  // https://www.sqlite.org/wal.html
-  string journal = ExecuteOrEmpty(this, "PRAGMA journal_mode=wal");
-  if (journal != "wal") {
-    LOG(WARNING) << "Failed to set journal_mode=wal because SQLite wants "
-                 << uri_ << " to be in '" << journal << "' mode, which might "
-                 << "be bad since WAL is important for the performance of "
-                 << "write-intensive apps. This might only happen for memory "
-                 << "databases or old versions of SQLite, but is definitely "
-                 << "worth fixing if that's not the case";
-  } else {
-    // This setting means we might lose transactions due to power loss,
-    // but the database can't become corrupted. In exchange, we get the
-    // the performance of a NoSQL database. This is a trade-off most data
-    // scientists would consider acceptable.
-    // https://www.sqlite.org/pragma.html#pragma_synchronous
-    ExecuteOrLog(this, "PRAGMA synchronous=NORMAL");
-  }
-}
-
-SqliteStatement Sqlite::Prepare(const string& sql) {
+xla::StatusOr<SqliteStatement> Sqlite::Prepare(const StringPiece& sql) {
+  SqliteLock lock(*this);
   sqlite3_stmt* stmt = nullptr;
-  int rc = sqlite3_prepare_v2(db_, sql.c_str(), sql.size() + 1, &stmt, nullptr);
-  if (rc == SQLITE_OK) {
-    return {stmt, SQLITE_OK, std::unique_ptr<string>(nullptr)};
-  } else {
-    return {nullptr, rc, std::unique_ptr<string>(new string(sql))};
+  int rc = sqlite3_prepare_v2(db_, sql.data(), static_cast<int>(sql.size()),
+                              &stmt, nullptr);
+  if (rc != SQLITE_OK) {
+    return PrintfStatus(rc, "Prepare() failed: %s: %.*s", errmsg(), sql.size(),
+                        sql.data());
+  }
+  return SqliteStatement(stmt, self_.lock());
+}
+
+Status SqliteStatement::Step(bool* is_done) {
+  DCHECK(stmt_ != nullptr);
+  if (TF_PREDICT_FALSE(bind_error_ != SQLITE_OK)) {
+    *is_done = true;
+    return PrintfStatus(bind_error_, "Bind(%d) failed: %s: %s",
+                        bind_error_parameter_, sqlite3_errstr(bind_error_),
+                        sql());
+  }
+  SqliteLock lock(*db_);
+  int rc = sqlite3_step(stmt_);
+  switch (rc) {
+    case SQLITE_ROW:
+      *is_done = false;
+      return Status::OK();
+    case SQLITE_DONE:
+      *is_done = true;
+      return Status::OK();
+    default:
+      *is_done = true;
+      return PrintfStatus(rc, "Step() failed: %s: %s", db_->errmsg(), sql());
   }
 }
 
-Status SqliteStatement::status() const {
-  Status s = Sqlite::MakeStatus(error_);
-  if (!s.ok()) {
-    if (stmt_ != nullptr) {
-      errors::AppendToMessage(&s, sqlite3_sql(stmt_));
-    } else {
-      errors::AppendToMessage(&s, *prepare_error_sql_);
-    }
+bool SqliteStatement::StepOrDie() {
+  bool is_done;
+  TF_CHECK_OK(Step(&is_done));
+  return !is_done;
+}
+
+Status SqliteStatement::StepOnce() {
+  bool is_done;
+  TF_RETURN_IF_ERROR(Step(&is_done));
+  if (TF_PREDICT_FALSE(is_done)) {
+    return errors::Internal("No rows returned: ", sql());
   }
+  return Status::OK();
+}
+
+const SqliteStatement& SqliteStatement::StepOnceOrDie() {
+  TF_CHECK_OK(StepOnce());
+  return *this;
+}
+
+Status SqliteStatement::StepAndReset() {
+  bool is_done;
+  Status s = Step(&is_done);
+  if (TF_PREDICT_FALSE(s.ok() && !is_done)) {
+    s = errors::Internal("Unexpected row: ", sql());
+  }
+  Reset();
   return s;
 }
 
-void SqliteStatement::CloseOrLog() {
-  if (stmt_ != nullptr) {
-    int rc = sqlite3_finalize(stmt_);
-    if (rc != SQLITE_OK) {
-      LOG(ERROR) << "destruct sqlite3_stmt: " << Sqlite::MakeStatus(rc);
-    }
-    stmt_ = nullptr;
-  }
-}
-
-Status SqliteStatement::Close() {
-  if (stmt_ == nullptr) {
-    return Status::OK();
-  }
-  int rc = sqlite3_finalize(stmt_);
-  if (rc == SQLITE_OK) {
-    stmt_ = nullptr;
-  }
-  Update(rc);
-  return status();
-}
+void SqliteStatement::StepAndResetOrDie() { TF_CHECK_OK(StepAndReset()); }
 
 void SqliteStatement::Reset() {
   if (TF_PREDICT_TRUE(stmt_ != nullptr)) {
     sqlite3_reset(stmt_);
-    sqlite3_clear_bindings(stmt_);  // not nullptr friendly
+    sqlite3_clear_bindings(stmt_);
   }
-  error_ = SQLITE_OK;
+  bind_error_ = SQLITE_OK;
+  size_ = 0;
 }
 
-Status SqliteStatement::Step(bool* isDone) {
-  if (TF_PREDICT_FALSE(error_ != SQLITE_OK)) {
-    *isDone = true;
-    return status();
-  }
-  int rc = sqlite3_step(stmt_);
-  switch (rc) {
-    case SQLITE_ROW:
-      *isDone = false;
-      return Status::OK();
-    case SQLITE_DONE:
-      *isDone = true;
-      return Status::OK();
-    default:
-      *isDone = true;
-      error_ = rc;
-      return status();
+SqliteTransaction::SqliteTransaction(Sqlite& db) : db_(&db) {
+  sqlite3_mutex_enter(sqlite3_db_mutex(db_->db_));
+  CHECK(!db_->is_in_transaction_);
+  db_->is_in_transaction_ = true;
+  Begin();
+}
+
+SqliteTransaction::~SqliteTransaction() {
+  // Rollback should only return an error if there's no transaction.
+  // Since the API performs auto-rollbacks in some cases, we ignore it.
+  sqlite3_step(db_->rollback_);
+  sqlite3_reset(db_->rollback_);
+  sqlite3_reset(db_->begin_);
+  db_->is_in_transaction_ = false;
+  sqlite3_mutex_leave(sqlite3_db_mutex(db_->db_));
+}
+
+void SqliteTransaction::Begin() {
+  // This shouldn't allocate memory or perform I/O. All it does is
+  // execute OP_AutoCommit(0, 0) a.k.a. BEGIN DEFERRED which flips
+  // the sqlite3::autoCommit bit.
+  if (sqlite3_step(db_->begin_) != SQLITE_DONE) {
+    // It shouldn't be possible for this to fail since we already
+    // performed the reentrancy check.
+    LOG(FATAL) << "BEGIN failed: " << sqlite3_errmsg(db_->db_);
   }
 }
 
-Status SqliteStatement::StepAndReset() {
-  if (TF_PREDICT_FALSE(error_ != SQLITE_OK)) {
-    return status();
-  }
-  Status s;
-  int rc = sqlite3_step(stmt_);
+Status SqliteTransaction::Commit() {
+  int rc = sqlite3_step(db_->commit_);
   if (rc != SQLITE_DONE) {
-    if (rc == SQLITE_ROW) {
-      s.Update(errors::Internal("unexpected sqlite row"));
-    } else {
-      s.Update(Sqlite::MakeStatus(rc));
-    }
+    return PrintfStatus(rc, "COMMIT failed: %s", sqlite3_errmsg(db_->db_));
   }
-  Reset();
-  return s;
+  sqlite3_reset(db_->commit_);
+  sqlite3_reset(db_->begin_);
+  Begin();
+  return Status::OK();
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/lib/db/sqlite.h b/tensorflow/core/lib/db/sqlite.h
index 12840bd..49a989a 100644
--- a/tensorflow/core/lib/db/sqlite.h
+++ b/tensorflow/core/lib/db/sqlite.h
@@ -17,155 +17,212 @@
 
 #include <cstddef>
 #include <memory>
+#include <mutex>
 #include <utility>
 
 #include "sqlite3.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/platform/types.h"
 
+/// TensorFlow SQLite Veneer
+///
+/// - Memory safety
+/// - Less boilerplate
+/// - Removes deprecated stuff
+/// - Pretends UTF16 doesn't exist
+/// - Transaction compile-time safety
+/// - Statically loads our native extensions
+/// - Error reporting via tensorflow::Status et al.
+///
+/// SQLite>=3.8.2 needs to be supported until April 2019, which is when
+/// Ubuntu 14.04 LTS becomes EOL.
+
 namespace tensorflow {
 
+class SqliteLock;
 class SqliteStatement;
+class SqliteTransaction;
 
 /// \brief SQLite connection object.
 ///
-/// This class is a thin wrapper around `sqlite3` that makes it easier
-/// and safer to use SQLite in the TensorFlow C++ codebase. It removes
-/// deprecated APIs, improves the safety of others, adds helpers, and
-/// pretends UTF16 doesn't exist.
+/// The SQLite connection is closed automatically by the destructor.
+/// Reference counting ensures that happens after its statements are
+/// destructed.
 ///
-/// Instances are thread safe, with the exception of Close().
-class Sqlite {
+/// This class offers the same thread safety behaviors and guarantees
+/// as the SQLite API itself.
+///
+/// This veneer uses auto-commit mode by default, which means a 4ms
+/// fsync() happens after every write unless a SqliteTransaction is
+/// used or WAL mode is enabled beforehand.
+class LOCKABLE Sqlite {
  public:
-  /// \brief Opens SQLite database file.
-  ///
-  /// The `uri` parameter can be a filename, or a proper URI like
-  /// `file:/tmp/tf.sqlite?mode=ro&cache=private`. It can also be
-  /// `file::memory:` for testing.
-  ///
-  /// See https://sqlite.org/c3ref/open.html
-  static xla::StatusOr<std::shared_ptr<Sqlite>> Open(const string& uri);
-
-  /// \brief Makes tensorflow::Status for SQLite result code.
-  ///
-  /// See https://sqlite.org/rescode.html
-  static Status MakeStatus(int resultCode);
-
-  /// \brief Destroys object and frees resources.
-  ///
-  /// This will free the underlying object if Close was not called. If
-  /// an error code is returned then it will be logged.
-  ///
-  /// Note: Unlike Close() this destructor maps to sqlite3_close_v2(),
-  /// which is lax about ordering and GC friendly.
+  /// \brief Closes SQLite connection, which can take milliseconds.
   ~Sqlite();
 
-  /// \brief Frees underlying SQLite object.
+  /// \brief Opens SQLite database file.
   ///
-  /// Unlike the destructor, all SqliteStatement objects must be closed
-  /// beforehand. This is a no-op if already closed
-  Status Close();
-
-  /// \brief Enables WAL mode with less fsync or log a warning.
+  /// Notes on a few of the flags:
   ///
-  /// The synchronous pragma is only set to NORMAL if WAL mode was
-  /// successfully enabled. This must be called immediately after
-  /// creating the object.
-  void UseWriteAheadLogWithReducedDurabilityIfPossible();
+  /// - SQLITE_OPEN_READONLY: Allowed if no WAL journal is active.
+  /// - SQLITE_OPEN_SHAREDCACHE: Will be ignored because this veneer
+  ///   doesn't support the unlock notify API.
+  /// - SQLITE_OPEN_NOMUTEX: Means access to this connection MUST be
+  ///   serialized by the caller in accordance with the same contracts
+  ///   implemented by this API.
+  ///
+  /// This function sets PRAGMA values from TF_SQLITE_* environment
+  /// variables. See sqlite.cc to learn more.
+  static xla::StatusOr<std::shared_ptr<Sqlite>> Open(string path, int flags);
+  static xla::StatusOr<std::shared_ptr<Sqlite>> Open(string path) {
+    return Open(std::move(path), SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE);
+  }
+  static std::shared_ptr<Sqlite> OpenOrDie(string path, int flags) {
+    return Open(std::move(path), flags).ValueOrDie();
+  }
+  static std::shared_ptr<Sqlite> OpenOrDie(string path) {
+    return Open(std::move(path)).ValueOrDie();
+  }
 
   /// \brief Creates SQLite statement.
   ///
-  /// Call result.status() to determine whether or not this operation
-  /// failed. It is also possible to punt the error checking to after
-  /// the values have been binded and Step() or ExecuteWriteQuery() is
-  /// called.
-  SqliteStatement Prepare(const string& sql);
+  /// If sql references tables then system calls taking microseconds
+  /// are needed and failure can happen on schema change. Otherwise
+  /// this should only fail on syntax error.
+  xla::StatusOr<SqliteStatement> Prepare(const StringPiece& sql);
+  SqliteStatement PrepareOrDie(const StringPiece& sql);
+
+  /// \brief Returns extended result code of last error.
+  ///
+  /// If the most recent API call was successful, the result is
+  /// undefined. The legacy result code can be obtained by saying
+  /// errcode() & 0xff.
+  int errcode() const EXCLUSIVE_LOCKS_REQUIRED(this) {
+    return sqlite3_extended_errcode(db_);
+  }
+
+  /// \brief Returns pointer to current error message state.
+  const char* errmsg() const EXCLUSIVE_LOCKS_REQUIRED(this) {
+    return sqlite3_errmsg(db_);
+  }
+
+  /// \brief Returns rowid assigned to last successful insert.
+  int64 last_insert_row_id() const EXCLUSIVE_LOCKS_REQUIRED(this) {
+    return sqlite3_last_insert_rowid(db_);
+  }
 
  private:
-  explicit Sqlite(sqlite3* db, const string& uri);
-  sqlite3* db_;
-  string uri_;
+  friend class SqliteLock;
+  friend class SqliteStatement;
+  friend class SqliteTransaction;
+
+  Sqlite(sqlite3* db, const string path, sqlite3_stmt* begin,
+         sqlite3_stmt* commit, sqlite3_stmt* rollback) noexcept
+      : db_(db),
+        path_(std::move(path)),
+        begin_(begin),
+        commit_(commit),
+        rollback_(rollback) {}
+
+  sqlite3* const db_;
+  const string path_;
+  sqlite3_stmt* const begin_;
+  sqlite3_stmt* const commit_;
+  sqlite3_stmt* const rollback_;
+  bool is_in_transaction_ = false;
+  std::weak_ptr<Sqlite> self_;  // so prepare can pass to statements
+
   TF_DISALLOW_COPY_AND_ASSIGN(Sqlite);
 };
 
-/// \brief SQLite prepared statement cursor object.
+/// \brief SQLite prepared statement.
 ///
-/// This class tracks error state internally, like Status::Update.
+/// Instances can only be shared between threads if caller serializes
+/// access from first Bind*() to *Reset().
 ///
-/// Instances of this class are not thread safe.
+/// When reusing a statement in a loop, be certain to not have jumps
+/// betwixt Bind*() and *Reset().
 class SqliteStatement {
  public:
-  /// \brief Constructs empty statement that should be assigned later.
-  SqliteStatement() : stmt_(nullptr), error_(SQLITE_OK) {}
+  /// \brief Initializes an empty statement to be assigned later.
+  SqliteStatement() noexcept = default;
 
-  /// \brief Empties object and finalizes statement if needed.
-  ~SqliteStatement() { CloseOrLog(); }
+  /// \brief Finalizes statement.
+  ///
+  /// This can take milliseconds if it was blocking the Sqlite
+  /// connection object from being freed.
+  ~SqliteStatement() { /* ignore */ sqlite3_finalize(stmt_); }
 
-  /// \brief Move constructor, after which <other> should not be used.
-  SqliteStatement(SqliteStatement&& other);
-
-  /// \brief Move assignment, after which <other> should not be used.
-  SqliteStatement& operator=(SqliteStatement&& other);
-
-  /// \brief Returns true if statement is not empty.
+  /// \brief Returns true if statement is initialized.
   explicit operator bool() const { return stmt_ != nullptr; }
 
-  /// \brief Returns SQLite result code state.
-  ///
-  /// This will be SQLITE_OK unless an error happened. If multiple
-  /// errors happened, only the first error code will be returned.
-  int error() const { return error_; }
+  /// \brief Returns SQL text from when this query was prepared.
+  const char* sql() const { return sqlite3_sql(stmt_); }
 
-  /// \brief Returns error() as a tensorflow::Status.
-  Status status() const;
+  /// \brief Number of bytes bound since last *Reset().
+  uint64 size() { return size_; }
 
-  /// \brief Finalize statement object.
+  /// \brief Executes query for fetching arbitrary rows.
   ///
-  /// Please note that the destructor can also do this. This method is
-  /// a no-op if already closed.
-  Status Close();
+  /// `is_done` will always be set to true unless SQLITE_ROW is
+  /// returned by the underlying API. If a preceding Bind*() call
+  /// failed, then the statement is not stepped and that bind error
+  /// is returned.
+  ///
+  /// The OrDie version returns `!is_done` which, if true, indicates a
+  /// row is available.
+  ///
+  /// This statement should be Reset() or destructed when finished
+  /// with the result.
+  Status Step(bool* is_done);
+  bool StepOrDie() TF_MUST_USE_RESULT;
 
-  /// \brief Executes query and/or fetches next row.
+  /// \brief Executes query when only one row is desired.
   ///
-  /// `isDone` will always be set to true unless SQLITE_ROW is returned
-  /// by the underlying API. If status() is already in an error state,
-  /// then this method is a no-op and the existing status is returned.
-  Status Step(bool* isDone);
+  /// If a row isn't returned, an internal error Status is returned
+  /// that won't be reflected in the connection error state.
+  ///
+  /// This statement should be Reset() or destructed when finished
+  /// with the result.
+  Status StepOnce();
+  const SqliteStatement& StepOnceOrDie();
 
-  /// \brief Executes query that returns no data.
+  /// \brief Executes query, ensures zero rows returned, then Reset().
   ///
-  /// This helper calls Step(), ensures SQLITE_DONE was returned, then
-  /// resets the statement and clears the bindings. If status() is
-  /// already in an error state, then this method is a no-op and the
-  /// existing status is returned.
+  /// If a row is returned, an internal error Status is returned that
+  /// won't be reflected in the connection error state.
   Status StepAndReset();
+  void StepAndResetOrDie();
 
   /// \brief Resets statement so it can be executed again.
   ///
-  /// - Resets the prepared statement
-  /// - Sets all Bind*() values to NULL
-  ///
-  /// Support for calling sqlite3_reset() and sqlite3_clear_bindings()
-  /// independently may be added in the future if a compelling use case
-  /// can be demonstrated.
+  /// Implementation note: This method diverges from canonical API
+  /// behavior by calling sqlite3_clear_bindings() in addition to
+  /// sqlite3_reset(). That makes the veneer safer; we haven't found a
+  /// super compelling reason yet to call them independently.
   void Reset();
 
   /// \brief Binds signed 64-bit integer to 1-indexed query parameter.
   void BindInt(int parameter, int64 value) {
-    Update(sqlite3_bind_int64(stmt_, parameter, value));
+    Update(sqlite3_bind_int64(stmt_, parameter, value), parameter);
+    size_ += sizeof(int64);
   }
-  void BindInt(const string& parameter, int64 value) {
+  void BindInt(const char* parameter, int64 value) {
     BindInt(GetParameterIndex(parameter), value);
   }
 
   /// \brief Binds double to 1-indexed query parameter.
   void BindDouble(int parameter, double value) {
-    Update(sqlite3_bind_double(stmt_, parameter, value));
+    Update(sqlite3_bind_double(stmt_, parameter, value), parameter);
+    size_ += sizeof(double);
   }
-  void BindDouble(const string& parameter, double value) {
+  void BindDouble(const char* parameter, double value) {
     BindDouble(GetParameterIndex(parameter), value);
   }
 
@@ -174,69 +231,67 @@
   /// If NUL characters are present, they will still go in the DB and
   /// be successfully retrieved by ColumnString(); however, the
   /// behavior of these values with SQLite functions is undefined.
-  void BindText(int parameter, const string& text) {
+  ///
+  /// When using the unsafe methods, the data must not be changed or
+  /// freed until this statement is Reset() or finalized.
+  void BindText(int parameter, const StringPiece& text) {
     Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(),
-                               SQLITE_TRANSIENT, SQLITE_UTF8));
+                               SQLITE_TRANSIENT, SQLITE_UTF8), parameter);
+    size_ += text.size();
   }
-  void BindText(const string& parameter, const string& text) {
+  void BindText(const char* parameter, const StringPiece& text) {
     BindText(GetParameterIndex(parameter), text);
   }
-
-  /// \brief Copies binary data to 1-indexed query parameter.
-  void BindBlob(int parameter, const string& blob) {
-    Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(),
-                               SQLITE_TRANSIENT));
-  }
-  void BindBlob(const string& parameter, const string& blob) {
-    BindBlob(GetParameterIndex(parameter), blob);
-  }
-
-  /// \brief Binds UTF-8 text to 1-indexed query parameter.
-  ///
-  /// The contents of `text` must not be changed or freed until Reset()
-  /// or Close() is called.
-  ///
-  /// If NUL characters are present, they will still go in the DB and
-  /// be successfully retrieved by ColumnString(); however, the
-  /// behavior of these values with SQLite functions is undefined.
-  void BindTextUnsafe(int parameter, const string& text) {
+  void BindTextUnsafe(int parameter, const StringPiece& text) {
     Update(sqlite3_bind_text64(stmt_, parameter, text.data(), text.size(),
-                               SQLITE_STATIC, SQLITE_UTF8));
+                               SQLITE_STATIC, SQLITE_UTF8), parameter);
+    size_ += text.size();
   }
-  void BindTextUnsafe(const string& parameter, const string& text) {
+  void BindTextUnsafe(const char* parameter, const StringPiece& text) {
     BindTextUnsafe(GetParameterIndex(parameter), text);
   }
 
-  /// \brief Binds binary data to 1-indexed query parameter.
+  /// \brief Copies binary data to 1-indexed query parameter.
   ///
-  /// The contents of `blob` must not be changed or freed until Reset()
-  /// or Close() is called.
-  void BindBlobUnsafe(int parameter, const string& blob) {
+  /// When using the unsafe methods, the data must not be changed or
+  /// freed until this statement is Reset() or finalized.
+  void BindBlob(int parameter, const StringPiece& blob) {
     Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(),
-                               SQLITE_STATIC));
+                               SQLITE_TRANSIENT), parameter);
+    size_ += blob.size();
   }
-  void BindBlobUnsafe(const string& parameter, const string& text) {
+  void BindBlob(const char* parameter, const StringPiece& blob) {
+    BindBlob(GetParameterIndex(parameter), blob);
+  }
+  void BindBlobUnsafe(int parameter, const StringPiece& blob) {
+    Update(sqlite3_bind_blob64(stmt_, parameter, blob.data(), blob.size(),
+                               SQLITE_STATIC), parameter);
+    size_ += blob.size();
+  }
+  void BindBlobUnsafe(const char* parameter, const StringPiece& text) {
     BindBlobUnsafe(GetParameterIndex(parameter), text);
   }
 
   /// \brief Returns number of columns in result set.
-  int ColumnCount() TF_MUST_USE_RESULT { return sqlite3_column_count(stmt_); }
+  int ColumnCount() const TF_MUST_USE_RESULT {
+    return sqlite3_column_count(stmt_);
+  }
 
   /// \brief Returns type of 0-indexed column value in row data.
   ///
   /// Please note that SQLite is dynamically typed and the type of a
   /// particular column can vary from row to row.
-  int ColumnType(int column) TF_MUST_USE_RESULT {
+  int ColumnType(int column) const TF_MUST_USE_RESULT {
     return sqlite3_column_type(stmt_, column);
   }
 
   /// \brief Returns 0-indexed column from row result coerced as an integer.
-  int64 ColumnInt(int column) TF_MUST_USE_RESULT {
+  int64 ColumnInt(int column) const TF_MUST_USE_RESULT {
     return sqlite3_column_int64(stmt_, column);
   }
 
   /// \brief Returns 0-indexed column from row result coerced as a double.
-  double ColumnDouble(int column) TF_MUST_USE_RESULT {
+  double ColumnDouble(int column) const TF_MUST_USE_RESULT {
     return sqlite3_column_double(stmt_, column);
   }
 
@@ -244,80 +299,141 @@
   ///
   /// NULL values are returned as empty string. This method should be
   /// used for both BLOB and TEXT columns. See also: ColumnType().
-  string ColumnString(int column) TF_MUST_USE_RESULT {
+  string ColumnString(int column) const TF_MUST_USE_RESULT {
     auto data = sqlite3_column_blob(stmt_, column);
-    if (data == nullptr) {
-      return "";
-    }
+    if (data == nullptr) return "";
     return {static_cast<const char*>(data),
             static_cast<size_t>(ColumnSize(column))};
   }
 
   /// \brief Returns pointer to binary data at 0-indexed column.
   ///
-  /// The returned memory will be mutated or freed the next time
-  /// Step() or Reset() is called. No NUL terminator is added. See
-  /// ColumnSize(). Please note that an empty BLOB is NULL.
-  const char* ColumnStringUnsafe(int column) TF_MUST_USE_RESULT {
-    return static_cast<const char*>(sqlite3_column_blob(stmt_, column));
+  /// Empty values are returned as a StringPiece whose data() is NULL.
+  /// The returned memory will no longer be valid the next time Step()
+  /// or Reset() is called. No NUL terminator is added.
+  StringPiece ColumnStringUnsafe(int column) const TF_MUST_USE_RESULT {
+    return {static_cast<const char*>(sqlite3_column_blob(stmt_, column)),
+            static_cast<size_t>(ColumnSize(column))};
   }
 
   /// \brief Returns number of bytes stored at 0-indexed column.
-  int ColumnSize(int column) TF_MUST_USE_RESULT {
+  int ColumnSize(int column) const TF_MUST_USE_RESULT {
     return sqlite3_column_bytes(stmt_, column);
   }
 
- private:
-  friend Sqlite;
-  SqliteStatement(sqlite3_stmt* stmt, int error,
-                  std::unique_ptr<string> prepare_error_sql)
-      : stmt_(stmt),
-        error_(error),
-        prepare_error_sql_(std::move(prepare_error_sql)) {}
-  void CloseOrLog();
+  /// \brief Move constructor, after which `other` is reset to empty.
+  SqliteStatement(SqliteStatement&& other) noexcept
+      : stmt_(other.stmt_),
+        db_(std::move(other.db_)),
+        bind_error_(other.bind_error_) {
+    other.stmt_ = nullptr;
+    other.bind_error_ = SQLITE_OK;
+  }
 
-  void Update(int rc) {
+  /// \brief Move assignment, after which `other` is reset to empty.
+  SqliteStatement& operator=(SqliteStatement&& other) noexcept {
+    if (&other != this) {
+      sqlite3_finalize(stmt_);
+      stmt_ = other.stmt_;
+      bind_error_ = other.bind_error_;
+      db_ = std::move(other.db_);
+      size_ = 0;
+      other.stmt_ = nullptr;
+      other.bind_error_ = SQLITE_OK;
+    }
+    return *this;
+  }
+
+ private:
+  friend class Sqlite;
+
+  SqliteStatement(sqlite3_stmt* stmt, std::shared_ptr<Sqlite> db) noexcept
+      : stmt_(stmt), db_(std::move(db)) {}
+
+  void Update(int rc, int parameter) {
+    // Binding strings can fail if they exceed length limit.
     if (TF_PREDICT_FALSE(rc != SQLITE_OK)) {
-      if (error_ == SQLITE_OK) {
-        error_ = rc;
+      if (bind_error_ == SQLITE_OK) {
+        bind_error_ = rc;
+        bind_error_parameter_ = parameter;
       }
     }
   }
 
-  int GetParameterIndex(const string& parameter) {
-    // Each call to this function requires O(n) strncmp().
-    int index = sqlite3_bind_parameter_index(stmt_, parameter.c_str());
-    if (TF_PREDICT_FALSE(index == 0)) {
-      Update(SQLITE_NOTFOUND);
-    }
+  int GetParameterIndex(const char* parameter) {
+    int index = sqlite3_bind_parameter_index(stmt_, parameter);
+    DCHECK(index > 0);  // OK to compile away since it'll fail again
     return index;
   }
 
-  sqlite3_stmt* stmt_;
-  int error_;
-  std::unique_ptr<string> prepare_error_sql_;
+  sqlite3_stmt* stmt_ = nullptr;
+  std::shared_ptr<Sqlite> db_;
+  int bind_error_ = SQLITE_OK;
+  int bind_error_parameter_ = 0;
+  uint64 size_ = 0;
 
   TF_DISALLOW_COPY_AND_ASSIGN(SqliteStatement);
 };
 
-inline SqliteStatement::SqliteStatement(SqliteStatement&& other)
-    : stmt_(other.stmt_),
-      error_(other.error_),
-      prepare_error_sql_(std::move(other.prepare_error_sql_)) {
-  other.stmt_ = nullptr;
-  other.error_ = SQLITE_OK;
-}
-
-inline SqliteStatement& SqliteStatement::operator=(SqliteStatement&& other) {
-  if (&other != this) {
-    CloseOrLog();
-    stmt_ = other.stmt_;
-    error_ = other.error_;
-    prepare_error_sql_ = std::move(other.prepare_error_sql_);
-    other.stmt_ = nullptr;
-    other.error_ = SQLITE_OK;
+/// \brief Reentrant SQLite connection object lock
+///
+/// This is a no-op if SQLITE_OPEN_NOMUTEX was used.
+class SCOPED_LOCKABLE SqliteLock {
+ public:
+  explicit SqliteLock(Sqlite& db) EXCLUSIVE_LOCK_FUNCTION(db)
+      : mutex_(sqlite3_db_mutex(db.db_)) {
+    sqlite3_mutex_enter(mutex_);
   }
-  return *this;
+  SqliteLock(Sqlite& db, std::try_to_lock_t) EXCLUSIVE_LOCK_FUNCTION(db)
+      : mutex_(sqlite3_db_mutex(db.db_)) {
+    if (TF_PREDICT_FALSE(sqlite3_mutex_try(mutex_) != SQLITE_OK)) {
+      is_locked_ = false;
+    }
+  }
+  ~SqliteLock() UNLOCK_FUNCTION() {
+    if (is_locked_) sqlite3_mutex_leave(mutex_);
+  }
+  explicit operator bool() const { return is_locked_; }
+
+ private:
+  sqlite3_mutex* const mutex_;
+  bool is_locked_ = true;
+  TF_DISALLOW_COPY_AND_ASSIGN(SqliteLock);
+};
+#define SqliteLock(x) static_assert(0, "sqlite_lock_decl_missing_name");
+
+/// \brief SQLite transaction scope.
+///
+/// This class acquires an exclusive lock on the connection object (if
+/// mutexes weren't disabled) and runs BEGIN / ROLLBACK automatically.
+/// Unlike SqliteLock this scope is non-reentrant. To avoid program
+/// crashes, business logic should use the EXCLUSIVE_LOCK_FUNCTION and
+/// LOCKS_EXCLUDED annotations as much as possible.
+class SCOPED_LOCKABLE SqliteTransaction {
+ public:
+  /// \brief Locks db and begins deferred transaction.
+  ///
+  /// This will crash if a transaction is already active.
+  explicit SqliteTransaction(Sqlite& db) EXCLUSIVE_LOCK_FUNCTION(db);
+
+  /// \brief Runs ROLLBACK and unlocks.
+  ~SqliteTransaction() UNLOCK_FUNCTION();
+
+  /// \brief Commits transaction.
+  ///
+  /// If this is successful, a new transaction will be started, which
+  /// is rolled back when exiting the scope.
+  Status Commit();
+
+ private:
+  void Begin();
+  Sqlite* const db_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(SqliteTransaction);
+};
+
+inline SqliteStatement Sqlite::PrepareOrDie(const StringPiece& sql) {
+  return Prepare(sql).ValueOrDie();
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/lib/db/sqlite_test.cc b/tensorflow/core/lib/db/sqlite_test.cc
index 29772b8..f93b3d8 100644
--- a/tensorflow/core/lib/db/sqlite_test.cc
+++ b/tensorflow/core/lib/db/sqlite_test.cc
@@ -14,13 +14,13 @@
 ==============================================================================*/
 #include "tensorflow/core/lib/db/sqlite.h"
 
-#include <limits.h>
 #include <array>
+#include <climits>
 
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/io/path.h"
-#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace tensorflow {
@@ -29,23 +29,22 @@
 class SqliteTest : public ::testing::Test {
  protected:
   void SetUp() override {
-    db_ = Sqlite::Open(":memory:").ValueOrDie();
-    auto stmt = db_->Prepare("CREATE TABLE T (a BLOB, b BLOB)");
-    TF_ASSERT_OK(stmt.StepAndReset());
+    db_ = Sqlite::OpenOrDie(":memory:");
+    db_->PrepareOrDie("CREATE TABLE T (a BLOB, b BLOB)").StepAndResetOrDie();
   }
   std::shared_ptr<Sqlite> db_;
   bool is_done_;
 };
 
 TEST_F(SqliteTest, InsertAndSelectInt) {
-  auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+  auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
   stmt.BindInt(1, 3);
   stmt.BindInt(2, -7);
   TF_ASSERT_OK(stmt.StepAndReset());
   stmt.BindInt(1, 123);
   stmt.BindInt(2, -123);
   TF_ASSERT_OK(stmt.StepAndReset());
-  stmt = db_->Prepare("SELECT a, b FROM T ORDER BY b");
+  stmt = db_->PrepareOrDie("SELECT a, b FROM T ORDER BY b");
   TF_ASSERT_OK(stmt.Step(&is_done_));
   ASSERT_FALSE(is_done_);
   EXPECT_EQ(123, stmt.ColumnInt(0));
@@ -59,11 +58,11 @@
 }
 
 TEST_F(SqliteTest, InsertAndSelectDouble) {
-  auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+  auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
   stmt.BindDouble(1, 6.28318530);
   stmt.BindDouble(2, 1.61803399);
   TF_ASSERT_OK(stmt.StepAndReset());
-  stmt = db_->Prepare("SELECT a, b FROM T");
+  stmt = db_->PrepareOrDie("SELECT a, b FROM T");
   TF_ASSERT_OK(stmt.Step(&is_done_));
   EXPECT_EQ(6.28318530, stmt.ColumnDouble(0));
   EXPECT_EQ(1.61803399, stmt.ColumnDouble(1));
@@ -74,11 +73,11 @@
 TEST_F(SqliteTest, NulCharsInString) {
   string s;  // XXX: Want to write {2, '\0'} but not sure why not.
   s.append(static_cast<size_t>(2), '\0');
-  auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+  auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
   stmt.BindBlob(1, s);
   stmt.BindText(2, s);
   TF_ASSERT_OK(stmt.StepAndReset());
-  stmt = db_->Prepare("SELECT a, b FROM T");
+  stmt = db_->PrepareOrDie("SELECT a, b FROM T");
   TF_ASSERT_OK(stmt.Step(&is_done_));
   EXPECT_EQ(2, stmt.ColumnSize(0));
   EXPECT_EQ(2, stmt.ColumnString(0).size());
@@ -92,58 +91,38 @@
 
 TEST_F(SqliteTest, Unicode) {
   string s = "要依法治国是赞美那些谁是公义的和惩罚恶人。 - 韩非";
-  auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+  auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
   stmt.BindBlob(1, s);
   stmt.BindText(2, s);
   TF_ASSERT_OK(stmt.StepAndReset());
-  stmt = db_->Prepare("SELECT a, b FROM T");
+  stmt = db_->PrepareOrDie("SELECT a, b FROM T");
   TF_ASSERT_OK(stmt.Step(&is_done_));
   EXPECT_EQ(s, stmt.ColumnString(0));
   EXPECT_EQ(s, stmt.ColumnString(1));
 }
 
 TEST_F(SqliteTest, StepAndResetClearsBindings) {
-  auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+  auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
   stmt.BindInt(1, 1);
   stmt.BindInt(2, 123);
   TF_ASSERT_OK(stmt.StepAndReset());
   stmt.BindInt(1, 2);
   TF_ASSERT_OK(stmt.StepAndReset());
-  stmt = db_->Prepare("SELECT b FROM T ORDER BY a");
+  stmt = db_->PrepareOrDie("SELECT b FROM T ORDER BY a");
   TF_ASSERT_OK(stmt.Step(&is_done_));
   EXPECT_EQ(123, stmt.ColumnInt(0));
   TF_ASSERT_OK(stmt.Step(&is_done_));
   EXPECT_EQ(SQLITE_NULL, stmt.ColumnType(0));
 }
 
-TEST_F(SqliteTest, CloseBeforeFinalizeFails) {
-  auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
-  Status s = db_->Close();
-  EXPECT_FALSE(s.ok());
-}
-
-// Rather than bothering to check the status code of creating a
-// statement and every single bind call afterwards, SqliteStatement
-// is designed to carry the first error state forward to Step().
-TEST_F(SqliteTest, ErrorPuntingDoesNotReportLibraryAbuse) {
-  auto stmt = db_->Prepare("lol cat");
-  EXPECT_FALSE(stmt.status().ok());
-  EXPECT_EQ(SQLITE_ERROR, stmt.error());
-  stmt.BindInt(1, 1);
-  stmt.BindInt(2, 2);
-  Status s = stmt.Step(&is_done_);
-  EXPECT_EQ(SQLITE_ERROR, stmt.error());  // first error of several
-  EXPECT_FALSE(s.ok());
-}
-
 TEST_F(SqliteTest, SafeBind) {
   string s = "hello";
-  auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+  auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
   stmt.BindBlob(1, s);
   stmt.BindText(2, s);
   s.at(0) = 'y';
   TF_ASSERT_OK(stmt.StepAndReset());
-  stmt = db_->Prepare("SELECT a, b FROM T");
+  stmt = db_->PrepareOrDie("SELECT a, b FROM T");
   TF_ASSERT_OK(stmt.Step(&is_done_));
   EXPECT_EQ("hello", stmt.ColumnString(0));
   EXPECT_EQ("hello", stmt.ColumnString(1));
@@ -151,42 +130,42 @@
 
 TEST_F(SqliteTest, UnsafeBind) {
   string s = "hello";
-  auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+  auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
   stmt.BindBlobUnsafe(1, s);
   stmt.BindTextUnsafe(2, s);
   s.at(0) = 'y';
   TF_ASSERT_OK(stmt.StepAndReset());
-  stmt = db_->Prepare("SELECT a, b FROM T");
+  stmt = db_->PrepareOrDie("SELECT a, b FROM T");
   TF_ASSERT_OK(stmt.Step(&is_done_));
   EXPECT_EQ("yello", stmt.ColumnString(0));
   EXPECT_EQ("yello", stmt.ColumnString(1));
 }
 
 TEST_F(SqliteTest, UnsafeColumn) {
-  auto stmt = db_->Prepare("INSERT INTO T (a, b) VALUES (?, ?)");
+  auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
   stmt.BindInt(1, 1);
   stmt.BindText(2, "hello");
   TF_ASSERT_OK(stmt.StepAndReset());
   stmt.BindInt(1, 2);
   stmt.BindText(2, "there");
   TF_ASSERT_OK(stmt.StepAndReset());
-  stmt = db_->Prepare("SELECT b FROM T ORDER BY a");
+  stmt = db_->PrepareOrDie("SELECT b FROM T ORDER BY a");
   TF_ASSERT_OK(stmt.Step(&is_done_));
-  const char* p = stmt.ColumnStringUnsafe(0);
-  EXPECT_EQ('h', *p);
+  StringPiece p = stmt.ColumnStringUnsafe(0);
+  EXPECT_EQ('h', *p.data());
   TF_ASSERT_OK(stmt.Step(&is_done_));
   // This will actually happen, but it's not safe to test this behavior.
-  // EXPECT_EQ('t', *p);
+  // EXPECT_EQ('t', *p.data());
 }
 
 TEST_F(SqliteTest, NamedParameterBind) {
-  auto stmt = db_->Prepare("INSERT INTO T (a) VALUES (:a)");
+  auto stmt = db_->PrepareOrDie("INSERT INTO T (a) VALUES (:a)");
   stmt.BindText(":a", "lol");
   TF_ASSERT_OK(stmt.StepAndReset());
-  stmt = db_->Prepare("SELECT COUNT(*) FROM T");
+  stmt = db_->PrepareOrDie("SELECT COUNT(*) FROM T");
   TF_ASSERT_OK(stmt.Step(&is_done_));
   EXPECT_EQ(1, stmt.ColumnInt(0));
-  stmt = db_->Prepare("SELECT a FROM T");
+  stmt = db_->PrepareOrDie("SELECT a FROM T");
   TF_ASSERT_OK(stmt.Step(&is_done_));
   EXPECT_FALSE(is_done_);
   EXPECT_EQ("lol", stmt.ColumnString(0));
@@ -195,57 +174,107 @@
 TEST_F(SqliteTest, Statement_DefaultConstructor) {
   SqliteStatement stmt;
   EXPECT_FALSE(stmt);
-  EXPECT_FALSE(stmt.StepAndReset().ok());
-  stmt = db_->Prepare("INSERT INTO T (a) VALUES (1)");
+  stmt = db_->PrepareOrDie("INSERT INTO T (a) VALUES (1)");
   EXPECT_TRUE(stmt);
   EXPECT_TRUE(stmt.StepAndReset().ok());
 }
 
 TEST_F(SqliteTest, Statement_MoveConstructor) {
-  SqliteStatement stmt{db_->Prepare("INSERT INTO T (a) VALUES (1)")};
+  SqliteStatement stmt{db_->PrepareOrDie("INSERT INTO T (a) VALUES (1)")};
   EXPECT_TRUE(stmt.StepAndReset().ok());
 }
 
 TEST_F(SqliteTest, Statement_MoveAssignment) {
-  SqliteStatement stmt1 = db_->Prepare("INSERT INTO T (a) VALUES (1)");
+  SqliteStatement stmt1 = db_->PrepareOrDie("INSERT INTO T (a) VALUES (1)");
   SqliteStatement stmt2;
   EXPECT_TRUE(stmt1.StepAndReset().ok());
-  EXPECT_FALSE(stmt2.StepAndReset().ok());
+  EXPECT_FALSE(stmt2);
   stmt2 = std::move(stmt1);
   EXPECT_TRUE(stmt2.StepAndReset().ok());
 }
 
 TEST_F(SqliteTest, PrepareFailed) {
-  SqliteStatement s = db_->Prepare("SELECT");
-  EXPECT_FALSE(s.status().ok());
-  EXPECT_NE(string::npos, s.status().error_message().find("SELECT"));
+  SqliteLock lock(*db_);
+  Status s = db_->Prepare("SELECT").status();
+  ASSERT_FALSE(s.ok());
+  EXPECT_NE(string::npos, s.error_message().find("SELECT"));
+  EXPECT_EQ(SQLITE_ERROR, db_->errcode());
 }
 
 TEST_F(SqliteTest, BindFailed) {
-  SqliteStatement s = db_->Prepare("INSERT INTO T (a) VALUES (123)");
-  EXPECT_TRUE(s.status().ok());
-  EXPECT_EQ("", s.status().error_message());
-  s.BindInt(1, 123);
-  EXPECT_FALSE(s.status().ok());
+  auto stmt = db_->PrepareOrDie("INSERT INTO T (a) VALUES (123)");
+  stmt.BindInt(1, 123);
+  Status s = stmt.StepOnce();
   EXPECT_NE(string::npos,
-            s.status().error_message().find("INSERT INTO T (a) VALUES (123)"));
+            s.error_message().find("INSERT INTO T (a) VALUES (123)"))
+      << s.error_message();
 }
 
 TEST_F(SqliteTest, SnappyExtension) {
-  auto stmt = db_->Prepare("SELECT UNSNAP(SNAP(?))");
+  auto stmt = db_->PrepareOrDie("SELECT UNSNAP(SNAP(?))");
   stmt.BindText(1, "hello");
-  TF_ASSERT_OK(stmt.Step(&is_done_));
-  EXPECT_FALSE(is_done_);
-  EXPECT_EQ("hello", stmt.ColumnString(0));
+  EXPECT_EQ("hello", stmt.StepOnceOrDie().ColumnString(0));
 }
 
 TEST_F(SqliteTest, SnappyBinaryCompatibility) {
-  auto stmt = db_->Prepare(
-      "SELECT UNSNAP(X'03207C746F6461792069732074686520656E64206F66207468652"
-      "072657075626C6963')");
-  TF_ASSERT_OK(stmt.Step(&is_done_));
-  EXPECT_FALSE(is_done_);
-  EXPECT_EQ("today is the end of the republic", stmt.ColumnString(0));
+  EXPECT_EQ(
+      "today is the end of the republic",
+      db_->PrepareOrDie("SELECT UNSNAP(X'03207C746F6461792069732074686520656E64"
+                        "206F66207468652072657075626C6963')")
+          .StepOnceOrDie()
+          .ColumnString(0));
+}
+
+TEST(SqliteOpenTest, CloseConnectionBeforeStatement_KeepsConnectionOpen) {
+  auto s = Sqlite::OpenOrDie(":memory:")->PrepareOrDie("SELECT ? + ?");
+  s.BindInt(1, 7);
+  s.BindInt(2, 3);
+  EXPECT_EQ(10, s.StepOnceOrDie().ColumnInt(0));
+}
+
+TEST_F(SqliteTest, TransactionRollback) {
+  {
+    SqliteTransaction txn(*db_);
+    auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
+    stmt.BindDouble(1, 6.28318530);
+    stmt.BindDouble(2, 1.61803399);
+    TF_ASSERT_OK(stmt.StepAndReset());
+  }
+  EXPECT_EQ(
+      0,
+      db_->PrepareOrDie("SELECT COUNT(*) FROM T").StepOnceOrDie().ColumnInt(0));
+}
+
+TEST_F(SqliteTest, TransactionCommit) {
+  {
+    SqliteTransaction txn(*db_);
+    auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
+    stmt.BindDouble(1, 6.28318530);
+    stmt.BindDouble(2, 1.61803399);
+    TF_ASSERT_OK(stmt.StepAndReset());
+    TF_ASSERT_OK(txn.Commit());
+  }
+  EXPECT_EQ(
+      1,
+      db_->PrepareOrDie("SELECT COUNT(*) FROM T").StepOnceOrDie().ColumnInt(0));
+}
+
+TEST_F(SqliteTest, TransactionCommitMultipleTimes) {
+  {
+    SqliteTransaction txn(*db_);
+    auto stmt = db_->PrepareOrDie("INSERT INTO T (a, b) VALUES (?, ?)");
+    stmt.BindDouble(1, 6.28318530);
+    stmt.BindDouble(2, 1.61803399);
+    TF_ASSERT_OK(stmt.StepAndReset());
+    TF_ASSERT_OK(txn.Commit());
+    stmt.BindDouble(1, 6.28318530);
+    stmt.BindDouble(2, 1.61803399);
+    TF_ASSERT_OK(stmt.StepAndReset());
+    TF_ASSERT_OK(txn.Commit());
+  }
+  EXPECT_EQ(
+      2,
+      db_->PrepareOrDie("SELECT COUNT(*) FROM T").StepOnceOrDie().ColumnInt(0));
 }
 
 }  // namespace
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 9329749..ddc4910 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -1413,6 +1413,10 @@
       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
       c->set_output(0, c->Vector(2));
       c->set_output(1, c->Scalar());
+      if (c->input_handle_shapes_and_types(0)) {
+        c->set_output_handle_shapes_and_types(
+            0, *c->input_handle_shapes_and_types(0));
+      }
       return Status::OK();
     })
     .Doc(R"doc(
diff --git a/tensorflow/core/platform/cloud/expiring_lru_cache.h b/tensorflow/core/platform/cloud/expiring_lru_cache.h
index 3fc23a4..c738497 100644
--- a/tensorflow/core/platform/cloud/expiring_lru_cache.h
+++ b/tensorflow/core/platform/cloud/expiring_lru_cache.h
@@ -88,6 +88,13 @@
     return s;
   }
 
+  /// Clear the cache.
+  void Clear() {
+    mutex_lock lock(mu_);
+    cache_.clear();
+    lru_list_.clear();
+  }
+
   /// Accessors for cache parameters.
   uint64 max_age() const { return max_age_; }
   size_t max_entries() const { return max_entries_; }
diff --git a/tensorflow/core/platform/cloud/expiring_lru_cache_test.cc b/tensorflow/core/platform/cloud/expiring_lru_cache_test.cc
index 8f8d574..3bc6db3 100644
--- a/tensorflow/core/platform/cloud/expiring_lru_cache_test.cc
+++ b/tensorflow/core/platform/cloud/expiring_lru_cache_test.cc
@@ -152,5 +152,27 @@
   EXPECT_EQ(num_compute_calls, 6);
 }
 
+TEST(ExpiringLRUCacheTest, Clear) {
+  ExpiringLRUCache<int> cache(1, 4);
+  cache.Insert("a", 1);
+  cache.Insert("b", 2);
+  cache.Insert("c", 3);
+  cache.Insert("d", 4);
+  int value = 0;
+  EXPECT_TRUE(cache.Lookup("a", &value));
+  EXPECT_EQ(value, 1);
+  EXPECT_TRUE(cache.Lookup("b", &value));
+  EXPECT_EQ(value, 2);
+  EXPECT_TRUE(cache.Lookup("c", &value));
+  EXPECT_EQ(value, 3);
+  EXPECT_TRUE(cache.Lookup("d", &value));
+  EXPECT_EQ(value, 4);
+  cache.Clear();
+  EXPECT_FALSE(cache.Lookup("a", &value));
+  EXPECT_FALSE(cache.Lookup("b", &value));
+  EXPECT_FALSE(cache.Lookup("c", &value));
+  EXPECT_FALSE(cache.Lookup("d", &value));
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/file_block_cache.cc b/tensorflow/core/platform/cloud/file_block_cache.cc
index 6831600..0375af5 100644
--- a/tensorflow/core/platform/cloud/file_block_cache.cc
+++ b/tensorflow/core/platform/cloud/file_block_cache.cc
@@ -237,6 +237,14 @@
   }
 }
 
+void FileBlockCache::Flush() {
+  mutex_lock lock(mu_);
+  block_map_.clear();
+  lru_list_.clear();
+  lra_list_.clear();
+  cache_size_ = 0;
+}
+
 void FileBlockCache::RemoveFile(const string& filename) {
   mutex_lock lock(mu_);
   RemoveFile_Locked(filename);
diff --git a/tensorflow/core/platform/cloud/file_block_cache.h b/tensorflow/core/platform/cloud/file_block_cache.h
index 74e792a..5c180e2 100644
--- a/tensorflow/core/platform/cloud/file_block_cache.h
+++ b/tensorflow/core/platform/cloud/file_block_cache.h
@@ -90,6 +90,9 @@
   /// Remove all cached blocks for `filename`.
   void RemoveFile(const string& filename) LOCKS_EXCLUDED(mu_);
 
+  /// Remove all cached data.
+  void Flush() LOCKS_EXCLUDED(mu_);
+
   /// Accessors for cache parameters.
   size_t block_size() const { return block_size_; }
   size_t max_bytes() const { return max_bytes_; }
diff --git a/tensorflow/core/platform/cloud/file_block_cache_test.cc b/tensorflow/core/platform/cloud/file_block_cache_test.cc
index ae87e0d..596fdbf 100644
--- a/tensorflow/core/platform/cloud/file_block_cache_test.cc
+++ b/tensorflow/core/platform/cloud/file_block_cache_test.cc
@@ -495,5 +495,25 @@
 
   EXPECT_EQ(1, num_requests);
 }
+
+TEST(FileBlockCacheTest, Flush) {
+  int calls = 0;
+  auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
+                          char* buffer, size_t* bytes_transferred) {
+    calls++;
+    memset(buffer, 'x', n);
+    *bytes_transferred = n;
+    return Status::OK();
+  };
+  FileBlockCache cache(16, 32, 0, fetcher);
+  std::vector<char> out;
+  TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out));
+  TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out));
+  EXPECT_EQ(calls, 1);
+  cache.Flush();
+  TF_EXPECT_OK(ReadCache(&cache, "", 0, 16, &out));
+  EXPECT_EQ(calls, 2);
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index dffd1c4..970a6b1 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -1377,6 +1377,15 @@
   return Status::OK();
 }
 
+// Flushes all caches for filesystem metadata and file contents. Useful for
+// reclaiming memory once filesystem operations are done (e.g. model is loaded),
+// or for resetting the filesystem to a consistent state.
+void GcsFileSystem::FlushCaches() {
+  file_block_cache_->Flush();
+  stat_cache_->Clear();
+  matching_paths_cache_->Clear();
+}
+
 // Creates an HttpRequest and sets several parameters that are common to all
 // requests.  All code (in GcsFileSystem) that creates an HttpRequest should
 // go through this method, rather than directly using http_request_factory_.
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index 731f97a..adde161 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -84,6 +84,8 @@
   Status DeleteRecursively(const string& dirname, int64* undeleted_files,
                            int64* undeleted_dirs) override;
 
+  void FlushCaches() override;
+
   /// These accessors are mainly for testing purposes, to verify that the
   /// environment variables that control these parameters are handled correctly.
   size_t block_size() const { return file_block_cache_->block_size(); }
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index 32bd946..772aec5 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -195,6 +195,49 @@
   EXPECT_EQ("0123", result);
 }
 
+TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) {
+  // Our underlying file in this test is a 15 byte file with contents
+  // "0123456789abcde".
+  std::vector<HttpRequest*> requests(
+      {new FakeHttpRequest(
+           "Uri: https://storage.googleapis.com/bucket/random_access.txt\n"
+           "Auth Token: fake_token\n"
+           "Range: 0-8\n"
+           "Timeouts: 5 1 20\n",
+           "012345678"),
+       new FakeHttpRequest(
+           "Uri: https://storage.googleapis.com/bucket/random_access.txt\n"
+           "Auth Token: fake_token\n"
+           "Range: 0-8\n"
+           "Timeouts: 5 1 20\n",
+           "012345678")});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)),
+                   9 /* block size */, 18 /* max bytes */,
+                   0 /* max staleness */, 0 /* stat cache max age */,
+                   0 /* stat cache max entries */,
+                   0 /* matching paths cache max age */,
+                   0 /* matching paths cache max entries */,
+                   0 /* initial retry delay */, kTestTimeoutConfig);
+
+  char scratch[100];
+  StringPiece result;
+  std::unique_ptr<RandomAccessFile> file;
+  TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
+  // Read the first chunk. The cache will be populated with the first block of
+  // 9 bytes.
+  scratch[5] = 'x';
+  TF_EXPECT_OK(file->Read(0, 4, &result, scratch));
+  EXPECT_EQ("0123", result);
+  EXPECT_EQ(scratch[5], 'x');  // Make sure we only copied 4 bytes.
+  // Flush caches and read the second chunk. This will be a cache miss, and
+  // the same block will be fetched again.
+  fs.FlushCaches();
+  TF_EXPECT_OK(file->Read(4, 4, &result, scratch));
+  EXPECT_EQ("4567", result);
+}
+
 TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) {
   // Our underlying file in this test is a 16 byte file with contents
   // "0123456789abcdef".
@@ -1270,6 +1313,50 @@
   }
 }
 
+TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) {
+  std::vector<HttpRequest*> requests(
+      {new FakeHttpRequest(
+           "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
+           "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n"
+           "Auth Token: fake_token\n"
+           "Timeouts: 5 1 10\n",
+           "{\"items\": [ "
+           "  { \"name\": \"path/subpath/file2.txt\" }]}"),
+       new FakeHttpRequest(
+           "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
+           "fields=items%2Fname%2CnextPageToken&prefix=path%2Fsubpath%2F\n"
+           "Auth Token: fake_token\n"
+           "Timeouts: 5 1 10\n",
+           "{\"items\": [ "
+           "  { \"name\": \"path/subpath/file2.txt\" }]}")});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)),
+                   0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+                   0 /* stat cache max age */, 0 /* stat cache max entries */,
+                   3600 /* matching paths cache max age */,
+                   0 /* matching paths cache max entries */,
+                   0 /* initial retry delay*/, kTestTimeoutConfig);
+
+  // This loop should trigger the first HTTP request to GCS.
+  for (int i = 0; i < 10; i++) {
+    std::vector<string> result;
+    TF_EXPECT_OK(
+        fs.GetMatchingPaths("gs://bucket/path/subpath/file2.txt", &result));
+    EXPECT_EQ(std::vector<string>({"gs://bucket/path/subpath/file2.txt"}),
+              result);
+  }
+  // After flushing caches, there should be another (identical) request to GCS.
+  fs.FlushCaches();
+  for (int i = 0; i < 10; i++) {
+    std::vector<string> result;
+    TF_EXPECT_OK(
+        fs.GetMatchingPaths("gs://bucket/path/subpath/file2.txt", &result));
+    EXPECT_EQ(std::vector<string>({"gs://bucket/path/subpath/file2.txt"}),
+              result);
+  }
+}
+
 TEST(GcsFileSystemTest, DeleteFile) {
   std::vector<HttpRequest*> requests(
       {new FakeHttpRequest(
@@ -1895,6 +1982,50 @@
   }
 }
 
+TEST(GcsFileSystemTest, Stat_Cache_Flush) {
+  std::vector<HttpRequest*> requests(
+      {new FakeHttpRequest(
+           "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
+           "file.txt?fields=size%2Cupdated\n"
+           "Auth Token: fake_token\n"
+           "Timeouts: 5 1 10\n",
+           strings::StrCat("{\"size\": \"1010\","
+                           "\"updated\": \"2016-04-29T23:15:24.896Z\"}")),
+       new FakeHttpRequest(
+           "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
+           "file.txt?fields=size%2Cupdated\n"
+           "Auth Token: fake_token\n"
+           "Timeouts: 5 1 10\n",
+           strings::StrCat("{\"size\": \"1010\","
+                           "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)),
+                   0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+                   3600 /* stat cache max age */,
+                   0 /* stat cache max entries */,
+                   0 /* matching paths cache max age */,
+                   0 /* matching paths cache max entries */,
+                   0 /* initial retry delay*/, kTestTimeoutConfig);
+  // There should be a single HTTP request to GCS for fs.Stat in this loop.
+  for (int i = 0; i < 10; i++) {
+    FileStatistics stat;
+    TF_EXPECT_OK(fs.Stat("gs://bucket/file.txt", &stat));
+    EXPECT_EQ(1010, stat.length);
+    EXPECT_NEAR(1461971724896, stat.mtime_nsec / 1000 / 1000, 1);
+    EXPECT_FALSE(stat.is_directory);
+  }
+  // After flushing caches, there should be a second request to GCS for fs.Stat.
+  fs.FlushCaches();
+  for (int i = 0; i < 10; i++) {
+    FileStatistics stat;
+    TF_EXPECT_OK(fs.Stat("gs://bucket/file.txt", &stat));
+    EXPECT_EQ(1010, stat.length);
+    EXPECT_NEAR(1461971724896, stat.mtime_nsec / 1000 / 1000, 1);
+    EXPECT_FALSE(stat.is_directory);
+  }
+}
+
 TEST(GcsFileSystemTest, IsDirectory_NotFound) {
   std::vector<HttpRequest*> requests(
       {new FakeHttpRequest(
diff --git a/tensorflow/core/platform/default/stacktrace.h b/tensorflow/core/platform/default/stacktrace.h
index 5f30732..436716d 100644
--- a/tensorflow/core/platform/default/stacktrace.h
+++ b/tensorflow/core/platform/default/stacktrace.h
@@ -17,12 +17,67 @@
 #define TENSORFLOW_CORE_PLATFORM_DEFAULT_STACKTRACE_H_
 
 #include "tensorflow/core/platform/platform.h"
+#if !defined(IS_MOBILE_PLATFORM) && defined(PLATFORM_POSIX) && \
+    (defined(__clang__) || defined(__GNUC__))
+#define TF_GENERATE_BACKTRACE
+#endif
+
+#if defined(TF_GENERATE_BACKTRACE)
+#include <dlfcn.h>
+#include <execinfo.h>
+#include <stdio.h>
+#include <string.h>
+#include <unistd.h>
+#endif  // defined(TF_GENERATE_BACKTRACE)
+
+#include <sstream>
+#include <string>
+#include "tensorflow/core/platform/abi.h"
 
 namespace tensorflow {
 
-inline string CurrentStackTrace() { return "No stack trace available"; }
+// Returns a printable trace of the current call stack, one symbol per line
+// (demangled when possible). When TF_GENERATE_BACKTRACE is not defined, a
+// fixed placeholder string is returned instead.
+inline std::string CurrentStackTrace() {
+#if defined(TF_GENERATE_BACKTRACE)
+  std::stringstream ss("");
+  ss << "*** Begin stack trace ***" << std::endl;
 
-inline void DebugWriteToString(const char* data, void* arg) {}
+  // Get the mangled stack trace.
+  int buffer_size = 128;
+  void* trace[128];
+  buffer_size = backtrace(trace, buffer_size);
+
+  for (int i = 0; i < buffer_size; ++i) {
+    const char* symbol = "";
+    Dl_info info;
+    if (dladdr(trace[i], &info)) {
+      if (info.dli_sname != nullptr) {
+        symbol = info.dli_sname;
+      }
+    }
+
+    std::string demangled = tensorflow::port::MaybeAbiDemangle(symbol);
+    if (demangled.length()) {
+      ss << "\t" << demangled << std::endl;
+    } else {
+      ss << "\t" << symbol << std::endl;
+    }
+  }
+
+  ss << "*** End stack trace ***" << std::endl;
+  return ss.str();
+#else
+  // Without this branch the function would fall off its end without returning
+  // a value, which is undefined behavior for a non-void function.
+  return "No stack trace available";
+#endif  // defined(TF_GENERATE_BACKTRACE)
+}
+
+inline void DebugWriteToString(const char* data, void* arg) {
+  reinterpret_cast<std::string*>(arg)->append(data);
+}
 
 // A dummy class that does nothing.  Someday, add real support.
 class SavedStackTrace {
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
index 816ffa4..d617ff1 100644
--- a/tensorflow/core/platform/env.cc
+++ b/tensorflow/core/platform/env.cc
@@ -108,6 +108,18 @@
   return file_system_registry_->Register(scheme, std::move(factory));
 }
 
+Status Env::FlushFileSystemCaches() {
+  std::vector<string> schemes;
+  TF_RETURN_IF_ERROR(GetRegisteredFileSystemSchemes(&schemes));
+  for (const string& scheme : schemes) {
+    FileSystem* fs = nullptr;
+    TF_RETURN_IF_ERROR(
+        GetFileSystemForFile(io::CreateURI(scheme, "", ""), &fs));
+    fs->FlushCaches();
+  }
+  return Status::OK();
+}
+
 Status Env::NewRandomAccessFile(const string& fname,
                                 std::unique_ptr<RandomAccessFile>* result) {
   FileSystem* fs;
diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h
index a0adf70..557bfa8 100644
--- a/tensorflow/core/platform/env.h
+++ b/tensorflow/core/platform/env.h
@@ -68,10 +68,13 @@
   /// \brief Returns the file system schemes registered for this Env.
   virtual Status GetRegisteredFileSystemSchemes(std::vector<string>* schemes);
 
-  // \brief Register a file system for a scheme.
+  /// \brief Register a file system for a scheme.
   virtual Status RegisterFileSystem(const string& scheme,
                                     FileSystemRegistry::Factory factory);
 
+  /// \brief Flush filesystem caches for all registered filesystems.
+  Status FlushFileSystemCaches();
+
   /// \brief Creates a brand new random access read-only file with the
   /// specified name.
 
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 233c370..47ddf0c 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -281,6 +281,15 @@
     StringPiece scheme, host, path;
     io::ParseURI(dir, &scheme, &host, &path);
     if (path.empty()) return errors::NotFound(dir, " not found");
+    // The special "flushed" file exists only if the filesystem's caches have
+    // been flushed.
+    if (path == "/flushed") {
+      if (flushed_) {
+        return Status::OK();
+      } else {
+        return errors::NotFound("FlushCaches() not called yet");
+      }
+    }
     return Env::Default()->FileExists(io::JoinPath(BaseDir(), path));
   }
 
@@ -295,10 +304,23 @@
     }
     return Env::Default()->CreateDir(io::JoinPath(BaseDir(), path));
   }
+
+  void FlushCaches() override { flushed_ = true; }
+
+ private:
+  bool flushed_ = false;
 };
 
 REGISTER_FILE_SYSTEM("tmpdirfs", TmpDirFileSystem);
 
+TEST_F(DefaultEnvTest, FlushFileSystemCaches) {
+  Env* env = Env::Default();
+  const string flushed = "tmpdirfs://testhost/flushed";
+  EXPECT_EQ(error::Code::NOT_FOUND, env->FileExists(flushed).code());
+  TF_EXPECT_OK(env->FlushFileSystemCaches());
+  TF_EXPECT_OK(env->FileExists(flushed));
+}
+
 TEST_F(DefaultEnvTest, RecursivelyCreateDirWithUri) {
   Env* env = Env::Default();
   const string create_path = "tmpdirfs://testhost/a/b/c/d";
diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc
index 938f5af..1475589 100644
--- a/tensorflow/core/platform/file_system.cc
+++ b/tensorflow/core/platform/file_system.cc
@@ -73,6 +73,8 @@
   return Status(tensorflow::error::FAILED_PRECONDITION, "Not a directory");
 }
 
+void FileSystem::FlushCaches() {}
+
 RandomAccessFile::~RandomAccessFile() {}
 
 WritableFile::~WritableFile() {}
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 903df96..d32efce 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -206,6 +206,9 @@
   ///  * UNIMPLEMENTED - The file factory doesn't support directories.
   virtual Status IsDirectory(const string& fname);
 
+  /// \brief Flushes any cached filesystem objects from memory.
+  virtual void FlushCaches();
+
   FileSystem() {}
 
   virtual ~FileSystem();
diff --git a/tensorflow/core/platform/stacktrace_handler.cc b/tensorflow/core/platform/stacktrace_handler.cc
new file mode 100644
index 0000000..ff31c97
--- /dev/null
+++ b/tensorflow/core/platform/stacktrace_handler.cc
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/platform.h"
+
+#if !defined(PLATFORM_GOOGLE) && !defined(IS_MOBILE_PLATFORM) && \
+    defined(PLATFORM_POSIX) && (defined(__clang__) || defined(__GNUC__))
+#define TF_GENERATE_STACKTRACE
+#endif
+
+#if defined(TF_GENERATE_STACKTRACE)
+#include <errno.h>
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/time.h>
+#include <unistd.h>
+#include <string>
+
+#include "tensorflow/core/platform/abi.h"
+#include "tensorflow/core/platform/stacktrace.h"
+
+#endif  // defined(TF_GENERATE_STACKTRACE)
+
+namespace tensorflow {
+namespace testing {
+
+#if defined(TF_GENERATE_STACKTRACE)
+// This function will print stacktrace to STDERR.
+// It avoids using malloc, so it makes sure to dump the stack even when the heap
+// is corrupted. However, it can dump mangled symbols.
+inline void SafePrintStackTrace() {
+  static const char begin_msg[] = "*** BEGIN MANGLED STACK TRACE ***\n";
+  (void)write(STDERR_FILENO, begin_msg, strlen(begin_msg));
+
+  int buffer_size = 128;
+  void *trace[128];
+  // Run backtrace to get the size of the stacktrace
+  buffer_size = backtrace(trace, buffer_size);
+
+  // Print a mangled stacktrace to STDERR as safely as possible.
+  backtrace_symbols_fd(trace, buffer_size, STDERR_FILENO);
+
+  static const char end_msg[] = "*** END MANGLED STACK TRACE ***\n\n";
+  (void)write(STDERR_FILENO, end_msg, strlen(end_msg));
+}
+
+static void StacktraceHandler(int sig, siginfo_t *si, void *v) {
+  // Make sure our handler does not deadlock. And this should be the last thing
+  // our program does. Therefore, set a timer to kill the program in 60
+  // seconds.
+  struct itimerval timer;
+  timer.it_value.tv_sec = 60;
+  timer.it_value.tv_usec = 0;
+  timer.it_interval.tv_sec = 0;
+  timer.it_interval.tv_usec = 0;
+  setitimer(ITIMER_REAL, &timer, 0);
+
+  struct sigaction sa_timeout;
+  memset(&sa_timeout, 0, sizeof(sa_timeout));
+  sa_timeout.sa_handler = SIG_DFL;
+  sigaction(SIGALRM, &sa_timeout, 0);
+
+  char buf[128];
+
+  snprintf(buf, sizeof(buf), "*** Received signal %d ***\n", sig);
+  (void)write(STDERR_FILENO, buf, strlen(buf));
+
+  // Print "a" stack trace, as safely as possible.
+  SafePrintStackTrace();
+
+  // Up until this line, we made sure not to allocate memory, to be able to dump
+  // a stack trace even in the event of heap corruption. After this line, we
+  // will try to print more human readable things to the terminal.
+  // But these have a higher probability to fail.
+  std::string stacktrace = CurrentStackTrace();
+  (void)write(STDERR_FILENO, stacktrace.c_str(), stacktrace.length());
+
+  // Abort the program.
+  struct sigaction sa;
+  sigemptyset(&sa.sa_mask);
+  sa.sa_flags = 0;
+  sa.sa_handler = SIG_DFL;
+  sigaction(SIGABRT, &sa, NULL);
+  abort();
+}
+
+void InstallStacktraceHandler() {
+  int handled_signals[] = {SIGSEGV, SIGABRT, SIGBUS, SIGILL, SIGFPE};
+
+  for (int i = 0; i < sizeof(handled_signals) / sizeof(int); i++) {
+    int sig = handled_signals[i];
+    struct sigaction sa;
+    struct sigaction osa;
+
+    sigemptyset(&sa.sa_mask);
+    sa.sa_flags = SA_SIGINFO | SA_RESETHAND;
+    sa.sa_sigaction = &StacktraceHandler;
+    if (sigaction(sig, &sa, &osa) != 0) {
+      char buf[128];
+      snprintf(buf, sizeof(buf),
+               "Warning, can't install backtrace signal handler for signal %d, "
+               "errno:%d \n",
+               sig, errno);
+      (void)write(STDERR_FILENO, buf, strlen(buf));
+    } else if (osa.sa_handler != SIG_DFL) {
+      char buf[128];
+      snprintf(buf, sizeof(buf),
+               "Warning, backtrace signal handler for signal %d overwrote "
+               "previous handler.\n",
+               sig);
+      (void)write(STDERR_FILENO, buf, strlen(buf));
+    }
+  }
+}
+
+#else
+void InstallStacktraceHandler() {}
+#endif  // defined(TF_GENERATE_STACKTRACE)
+
+}  // namespace testing
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/stacktrace_handler.h b/tensorflow/core/platform/stacktrace_handler.h
new file mode 100644
index 0000000..d36c82c
--- /dev/null
+++ b/tensorflow/core/platform/stacktrace_handler.h
@@ -0,0 +1,28 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_
+#define TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_
+
+namespace tensorflow {
+namespace testing {
+
+// Installs signal handlers to print out stack trace.
+void InstallStacktraceHandler();
+
+}  // namespace testing
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_
diff --git a/tensorflow/core/platform/stacktrace_handler_test.cc b/tensorflow/core/platform/stacktrace_handler_test.cc
new file mode 100644
index 0000000..958c7de
--- /dev/null
+++ b/tensorflow/core/platform/stacktrace_handler_test.cc
@@ -0,0 +1,92 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Testing proper operation of the stacktrace handler.
+
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+#include <unistd.h>
+#include <string>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+#define READ_BUFFER_SIZE 1024
+
+TEST(StacktraceHandlerTest, GeneratesStacktrace) {
+  // Create a pipe through which the parent reads the child's stdout/stderr.
+  int test_pipe[2];
+  ASSERT_EQ(pipe(test_pipe), 0);
+
+  // Fork the process.
+  int test_pid = fork();
+  // Bail out if fork() failed: kill(-1, ...) below would otherwise signal
+  // every process this user is allowed to signal.
+  ASSERT_NE(test_pid, -1);
+
+  if (test_pid == 0) {
+    // Child process.
+    // Close the read end of the pipe, redirect stdout and stderr, and sleep.
+    close(test_pipe[0]);
+    dup2(test_pipe[1], STDOUT_FILENO);
+    dup2(test_pipe[1], STDERR_FILENO);
+    sleep(10);
+    // The signal never arrived. Exit here so that the child does not return
+    // into the test runner and execute the remaining tests a second time.
+    _exit(1);
+  } else {
+    // Parent process.
+    // Close the write end of the pipe, wait a little and send SIGABRT to the
+    // child process. Then watch the pipe.
+    close(test_pipe[1]);
+    sleep(1);
+
+    // Send the signal.
+    kill(test_pid, SIGABRT);
+
+    // Read from the pipe.
+    char buffer[READ_BUFFER_SIZE];
+    std::string child_output = "";
+    while (true) {
+      int read_length = read(test_pipe[0], buffer, READ_BUFFER_SIZE);
+      if (read_length > 0) {
+        child_output += std::string(buffer, read_length);
+      } else {
+        break;
+      }
+    }
+    close(test_pipe[0]);
+
+    // Reap the child so that it does not linger as a zombie.
+    waitpid(test_pid, nullptr, 0);
+
+    // Just make sure we can detect one of the calls in testing stack.
+    string test_stack_frame = "testing::internal::UnitTestImpl::RunAllTests()";
+
+    // Print the stack trace detected for information.
+    LOG(INFO) << "Output from the child process:";
+    LOG(INFO) << child_output;
+
+    EXPECT_NE(child_output.find(test_stack_frame), std::string::npos);
+  }
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/platform/test_main.cc b/tensorflow/core/platform/test_main.cc
index 96c88af..677114f 100644
--- a/tensorflow/core/platform/test_main.cc
+++ b/tensorflow/core/platform/test_main.cc
@@ -27,12 +27,14 @@
 #include <iostream>
 
 #include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/stacktrace_handler.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 
 GTEST_API_ int main(int argc, char** argv) {
   std::cout << "Running main() from test_main.cc\n";
 
+  tensorflow::testing::InstallStacktraceHandler();
   testing::InitGoogleTest(&argc, argv);
   for (int i = 1; i < argc; i++) {
     if (tensorflow::StringPiece(argv[i]).starts_with("--benchmarks=")) {
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index 6d6297d..47b1251 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -113,6 +113,29 @@
 [Stack Overflow](http://stackoverflow.com/questions/tagged/tensorflow)
 for possible solutions.  You can skip reading the rest of this document.
 
+### GPU support
+
+If your Linux system has an NVIDIA® GPU and your TensorFlow Java program
+requires GPU acceleration, then add the following to the project's `pom.xml`
+instead:
+
+```xml
+<dependency>
+  <groupId>org.tensorflow</groupId>
+  <artifactId>libtensorflow</artifactId>
+  <version>1.4.0</version>
+</dependency>
+<dependency>
+  <groupId>org.tensorflow</groupId>
+  <artifactId>libtensorflow_jni_gpu</artifactId>
+  <version>1.4.0</version>
+</dependency>
+```
+
+GPU acceleration is available via Maven only for Linux and only if your system
+meets the
+@{$install_linux#determine_which_tensorflow_to_install$requirements for GPU}.
+
 ## Using TensorFlow with JDK
 
 This section describes how to use TensorFlow using the `java` and `javac`
diff --git a/tensorflow/docs_src/performance/datasets_performance.md b/tensorflow/docs_src/performance/datasets_performance.md
index dd55849..4f95e17 100644
--- a/tensorflow/docs_src/performance/datasets_performance.md
+++ b/tensorflow/docs_src/performance/datasets_performance.md
@@ -224,7 +224,7 @@
 to:
 
 ```
-dataset = files.apply(tf.data.contrib.parallel_interleave(
+dataset = files.apply(tf.contrib.data.parallel_interleave(
     tf.data.TFRecordDataset, cycle_length=FLAGS.num_parallel_readers))
 ```
 
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index dceaf71..97467c5 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3719,6 +3719,7 @@
     srcs = ["training/session_manager_test.py"],
     additional_deps = [
         ":array_ops",
+        ":control_flow_ops",
         ":client",
         ":client_testlib",
         ":errors",
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index 00c38ea..909150e 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -154,8 +154,7 @@
       sess: A tensorflow Session object.
     """
 
-    _check_type(sess, (session.SessionInterface,
-                       monitored_session.MonitoredSession))
+    _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
     self.session = sess
 
 
@@ -359,8 +358,7 @@
       NotImplementedError: If a non-DirectSession sess object is received.
     """
 
-    _check_type(sess, (session.SessionInterface,
-                       monitored_session.MonitoredSession))
+    _check_type(sess, (session.BaseSession, monitored_session.MonitoredSession))
 
     # The session being wrapped.
     self._sess = sess
diff --git a/tensorflow/python/debug/wrappers/framework_test.py b/tensorflow/python/debug/wrappers/framework_test.py
index 5240e0d..73e08ce 100644
--- a/tensorflow/python/debug/wrappers/framework_test.py
+++ b/tensorflow/python/debug/wrappers/framework_test.py
@@ -271,9 +271,9 @@
   def testSessionInitInvalidSessionType(self):
     """Attempt to wrap a non-Session-type object should cause an exception."""
 
-    sess = "not a session"
+    wrapper = TestDebugWrapperSessionBadAction(self._sess)
     with self.assertRaisesRegexp(TypeError, "Expected type .*; got type .*"):
-      TestDebugWrapperSessionBadAction(sess)
+      TestDebugWrapperSessionBadAction(wrapper)
 
   def testSessionInitBadActionValue(self):
     with self.assertRaisesRegexp(
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 8aec242..3173afc 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -133,6 +133,9 @@
     """Set a global eager mode seed for random ops."""
     self._seed = seed
     self._rng = random.Random(self._seed)
+    # Also clear the kernel cache, to reset any existing seeds
+    if self._context_handle is not None:
+      pywrap_tensorflow.TFE_ContextClearCaches(self._context_handle)
 
   def _internal_operation_seed(self):
     """Returns a fake operation seed.
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 487f1b0..d9c1c3c 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -380,6 +380,93 @@
       self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  def testReduceSumAlongHWC(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      conv = _two_layer_model(x)
+      reduce_sum = math_ops.reduce_sum(conv, axis=[1, 2, 3])
+      output = array_ops.identity(reduce_sum)
+
+      with session.Session() as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if node.name.startswith('LayoutOptimizerTranspose'):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # Three transposes were initially added in the Expand phase of
+      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+      expected_num_transposes = 1
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
+  def testReduceSumAlongNHW(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      conv = _two_layer_model(x)
+      reduce_sum = math_ops.reduce_sum(conv, axis=[0, 1, 2])
+      output = array_ops.identity(reduce_sum)
+
+      with session.Session() as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if node.name.startswith('LayoutOptimizerTranspose'):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # Three transposes were initially added in the Expand phase of
+      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+      expected_num_transposes = 1
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
+  def testReduceSumAlongC(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      conv = _two_layer_model(x)
+      reduce_sum = math_ops.reduce_sum(conv, axis=[3])
+      output = array_ops.identity(reduce_sum)
+
+      with session.Session() as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if node.name.startswith('LayoutOptimizerTranspose'):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # Three transposes were initially added in the Expand phase of
+      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+      expected_num_transposes = 1
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   def testConcatWithControlDependency(self):
     if test.is_gpu_available(cuda_only=True):
       random_seed.set_random_seed(0)
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py
index 936da39..7650bfb 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py
@@ -399,7 +399,7 @@
       model.add(keras.layers.Activation('softmax'))
       model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
 
-      np.random.seed(1337)
+      np.random.seed(43)
       (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
           train_samples=train_samples,
           test_samples=test_samples,
diff --git a/tensorflow/python/kernel_tests/random/random_ops_test.py b/tensorflow/python/kernel_tests/random/random_ops_test.py
index 56aaa53..5a2903a 100644
--- a/tensorflow/python/kernel_tests/random/random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random/random_ops_test.py
@@ -21,6 +21,7 @@
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -174,6 +175,17 @@
       diff = rnd2 - rnd1
       self.assertTrue(np.linalg.norm(diff.eval()) > 0.1)
 
+  def testEagerSeed(self):
+    with context.eager_mode():
+      # Ensure a context has been created
+      random_ops.random_normal([])
+      # Set the same seed twice and check that the values match
+      context.set_global_seed(42)
+      rnd1 = random_ops.random_normal([])
+      context.set_global_seed(42)
+      rnd2 = random_ops.random_normal([])
+      self.assertAllEqual(rnd1, rnd2)
+
 
 class RandomUniformTest(test.TestCase):
 
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 126e39a..00faf3f 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -131,6 +131,9 @@
 
     self._init_set_name(name)
 
+    # Holds functions for creating regularizer ops.
+    self._regularizer_factories = []
+
     # Determine variable scope.
     scope = kwargs.get('_scope')
     if scope:
@@ -291,6 +294,22 @@
       inputs_hash = None
     return self._per_input_updates.get(inputs_hash, [])
 
+  def _get_regularizer_factories(self):
+    try:
+      # Some subclasses of Layer do not use its constructor.
+      return self._regularizer_factories
+    except AttributeError:
+      self._regularizer_factories = []
+      return self._regularizer_factories
+
+  def _maybe_create_variable_regularizers(self):
+    """Creates added but uninstantiated regularizers."""
+    factories = self._get_regularizer_factories()
+    if factories:
+      for factory in factories:
+        factory()
+      factories[:] = []
+
   @property
   def losses(self):
     """Losses which are associated with this `Layer`.
@@ -302,6 +321,7 @@
     Returns:
       A list of tensors.
     """
+    self._maybe_create_variable_regularizers()
     if context.in_eager_mode():
       # _losses may only contain variable regularization losses when executing
       # eagerly, and they have been saved as lambdas to be executed when
@@ -385,6 +405,7 @@
       inputs_hash = layers_util.object_list_uid(inputs)
     else:
       inputs_hash = None
+    self._maybe_create_variable_regularizers()
     return self._per_input_losses.get(inputs_hash, [])
 
   def build(self, _):
@@ -479,17 +500,20 @@
       instance is returned.
 
     Raises:
-      RuntimeError: If called in Eager mode with regularizers.
+      RuntimeError: If called in Eager mode with partitioned variable
+        regularization.
     """
-    if context.in_graph_mode():
+
+    in_graph_mode = context.in_graph_mode()
+    if in_graph_mode:
       existing_variables = set(tf_variables.global_variables())
     if dtype is None:
       dtype = self.dtype or dtypes.float32
 
     self._set_scope(None)
+    reuse = self.built or self._reuse
     with vs.variable_scope(
-        self._scope, reuse=(self.built or self._reuse),
-        auxiliary_name_scope=False) as scope:
+        self._scope, reuse=reuse, auxiliary_name_scope=False) as scope:
       with ops.name_scope(self._name_scope_name(scope)):
         variable = vs.get_variable(name,
                                    shape=shape,
@@ -498,39 +522,56 @@
                                    constraint=constraint,
                                    trainable=trainable and self.trainable,
                                    partitioner=partitioner)
-        if context.in_graph_mode():
+
+        if in_graph_mode:
           if (trainable and self.trainable
               and variable not in tf_variables.trainable_variables()):
             # A custom getter / variable scope overrode the trainable flag.
             trainable = False
           if variable in existing_variables:
+            # To match the behavior of tf.get_variable(), we only apply
+            # regularization if the variable is newly created.
             return variable
-          if regularizer:
-            # To match the behavior of tf.get_variable(), we only
-            # apply regularization if the variable is newly created.
-            if isinstance(variable, tf_variables.PartitionedVariable):
-              for v in variable:
-                with ops.colocate_with(v.op):
-                  with ops.name_scope(name + '/Regularizer'):
-                    regularization = regularizer(v)
-                if regularization is not None:
-                  self.add_loss(regularization)
+
+        if regularizer:
+          def regularizer_factory():
+            if context.in_graph_mode():
+              with vs.variable_scope(scope, reuse=reuse,
+                                     auxiliary_name_scope=False):
+                with ops.name_scope(self._name_scope_name(scope)):
+                  if isinstance(variable, tf_variables.PartitionedVariable):
+                    for v in variable:
+                      with ops.colocate_with(v.op):
+                        with ops.name_scope(name + '/Regularizer'):
+                          regularization = regularizer(v)
+                      if regularization is not None:
+                        self.add_loss(regularization)
+                  else:
+                    with ops.colocate_with(variable.op):
+                      with ops.name_scope(name + '/Regularizer'):
+                        regularization = regularizer(variable)
+                    if regularization is not None:
+                      self.add_loss(regularization)
             else:
-              with ops.colocate_with(variable.op):
-                with ops.name_scope(name + '/Regularizer'):
-                  regularization = regularizer(variable)
-              if regularization is not None:
-                self.add_loss(regularization)
-        elif regularizer:
-          if isinstance(variable, tf_variables.PartitionedVariable):
-            raise RuntimeError(
-                'Partitioned variable regularization is not yet supported when '
-                'executing eagerly. File a feature request is this is '
-                'important to you.')
-          # Save a zero-argument lambda which runs the regularizer on the
-          # variable, to be executed when `Layer.losses` is requested. This
-          # makes losses responsive to variable updates when executing eagerly.
-          self._losses.append(lambda: regularizer(variable))
+              if isinstance(variable, tf_variables.PartitionedVariable):
+                raise RuntimeError(
+                    'Partitioned variable regularization is not yet '
+                    'supported when executing eagerly. File a feature request '
+                    'if this is important to you.')
+              # Save a zero-argument lambda which runs the regularizer on the
+              # variable, to be executed when `Layer.losses` is requested.
+              # This makes losses responsive to variable updates when
+              # executing eagerly.
+              self._losses.append(lambda: regularizer(variable))
+
+          if hasattr(self, '_defer_regularizers') and self._defer_regularizers:
+            # _defer_regularizers exists and is set to True if `build` was
+            # invoked in `__call__`: deferring regularizer construction
+            # prevents the regularizer from being created in an `init_scope`.
+            self._get_regularizer_factories().append(regularizer_factory)
+          else:
+            regularizer_factory()
+
     if trainable:
       self._trainable_weights.append(variable)
     else:
@@ -629,7 +670,15 @@
             except AttributeError:
               pass
           input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
-          self.build(input_shapes)
+
+          # Signal to `add_variable` that regularizer construction should be
+          # deferred.
+          self._defer_regularizers = True
+          with ops.init_scope():
+            self.build(input_shapes)
+          # Create any regularizers added by `build`.
+          self._maybe_create_variable_regularizers()
+          self._defer_regularizers = False
         try:
           # Note: not all sub-classes of Layer call Layer.__init__ (especially
           # the ones under tensorflow/python/keras). Hence we recompute this
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index c26b136..06ba214 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -531,6 +531,30 @@
       self.assertEqual(layer2.my_var.name, 'name_3/my_var:0')
       self.assertEqual(op2.name, 'name_3/my_op:0')
 
+  def testVariablesAreLiftedFromFunctionBuildingGraphs(self):
+    class MyLayer(base_layers.Layer):
+
+      def build(self, input_shape):
+        self.my_var = self.add_variable('my_var', (), dtypes.float32)
+        self.built = True
+
+      def call(self, inputs):
+        return inputs
+
+    outer_graph = ops.get_default_graph()
+    function_building_graph = ops.Graph()
+    function_building_graph._building_function = True
+    with outer_graph.as_default():
+      with function_building_graph.as_default():
+        layer = MyLayer()
+        # Create a variable by invoking build through __call__ and assert that
+        # it is both tracked and lifted into the outer graph.
+        inputs = array_ops.placeholder(dtypes.float32, (), 'inputs')
+        layer.apply(inputs)
+        self.assertEqual(len(layer.variables), 1)
+        self.assertEqual(len(layer.trainable_variables), 1)
+        self.assertEqual(layer.variables[0].graph, outer_graph)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 74b4056..88d1ce5 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1494,20 +1494,17 @@
       zero = ""
     else:
       zero = 0
-    # Checking for boolean dtype to prevent attempting to run fill on the GPU
-    # which does not have a boolean kernel registered.
-    if context.in_eager_mode() and dtype != dtypes.bool:
-      return fill(shape, constant(zero, dtype=dtype), name=name)
-    try:
-      if isinstance(shape, ops.Tensor):
-        # TODO(apassos) this is required to reproduce the behavior from before
-        # Tensors were iterable. It's a crutch.
-        raise TypeError
-      shape = tensor_shape.as_shape(shape)
-      output = constant(zero, shape=shape, dtype=dtype, name=name)
-    except (TypeError, ValueError):
-      shape = ops.convert_to_tensor(shape, dtype=dtypes.int32, name="shape")
-      output = fill(shape, constant(zero, dtype=dtype), name=name)
+    if not isinstance(shape, ops.Tensor):
+      try:
+        # Go through tensor shapes to get int64-if-needed semantics
+        shape = constant_op._tensor_shape_tensor_conversion_function(
+            tensor_shape.TensorShape(shape))
+      except (TypeError, ValueError):
+        # Happens when shape is a list with tensor elements
+        shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
+    if not shape._shape_tuple():
+      shape = reshape(shape, [-1])  # Ensure it's a vector
+    output = fill(shape, constant(zero, dtype=dtype), name=name)
   assert output.dtype.base_dtype == dtype
   return output
 
@@ -1625,15 +1622,17 @@
   dtype = dtypes.as_dtype(dtype).base_dtype
   with ops.name_scope(name, "ones", [shape]) as name:
     one = True if dtype == dtypes.bool else 1
-    try:
-      if isinstance(shape, ops.Tensor):
-        raise TypeError(
-            "preserving semantics from before tensors were iterable")
-      shape = tensor_shape.as_shape(shape)
-      output = constant(one, shape=shape, dtype=dtype, name=name)
-    except (TypeError, ValueError):
-      shape = ops.convert_to_tensor(shape, dtype=dtypes.int32, name="shape")
-      output = fill(shape, constant(one, dtype=dtype), name=name)
+    if not isinstance(shape, ops.Tensor):
+      try:
+        # Go through tensor shapes to get int64-if-needed semantics
+        shape = constant_op._tensor_shape_tensor_conversion_function(
+            tensor_shape.TensorShape(shape))
+      except (TypeError, ValueError):
+        # Happens when shape is a list with tensor elements
+        shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
+    if not shape._shape_tuple():
+      shape = reshape(shape, [-1])  # Ensure it's a vector
+    output = fill(shape, constant(one, dtype=dtype), name=name)
   assert output.dtype.base_dtype == dtype
   return output
 
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index ab7a903..879c206 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -276,10 +276,6 @@
           dtype=dtype,
           constraint=constraint)
 
-  # LINT.IfChange
-  # _VariableFromResource inherits from ResourceVariable but
-  # doesn't call the constructor, so changes here might need to be reflected
-  # there.
   # pylint: disable=unused-argument
   def _init_from_args(self,
                       initial_value=None,
@@ -438,7 +434,8 @@
               self._initializer_op = (
                   gen_resource_variable_ops.assign_variable_op(
                       self._handle,
-                      self._build_initializer_expr(initial_value),
+                      self._try_guard_against_uninitialized_dependencies(
+                          initial_value),
                       name=n))
           with ops.name_scope("Read"), ops.colocate_with(self._handle):
             # Manually assign reads to the handle's device to avoid log
@@ -522,7 +519,6 @@
     self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
     self._graph_element = self.value()
     self._constraint = None
-  # LINT.ThenChange(//tensorflow/python/eager/graph_callable.py)
 
   def __nonzero__(self):
     return self.__bool__()
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index 07796b2..e0cf1bf 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -127,8 +127,8 @@
 
   Returns:
     A function to encapsulate a set of variables which should be created once
-    and reused. An enclosing scope will created, either where `make_template`
-    is called, or wherever the result is called, depending on the value of
+    and reused. An enclosing scope will be created either when `make_template`
+    is called or when the result is called, depending on the value of
     `create_scope_now_`. Regardless of the value, the first time the template
     is called it will enter the scope with no reuse, and call `func_` to create
     variables, which are guaranteed to be unique. All subsequent calls will
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index e0748d8..b258556 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -362,7 +362,8 @@
         # using their initialized_value() method.
         self._initializer_op = state_ops.assign(
             self._variable,
-            self._build_initializer_expr(self._initial_value),
+            self._try_guard_against_uninitialized_dependencies(
+                self._initial_value),
             validate_shape=validate_shape).op
 
         # TODO(vrv): Change this class to not take caching_device, but
@@ -781,88 +782,142 @@
 
     setattr(Variable, operator, _run_op)
 
-  def _build_initializer_expr(self, initial_value):
-    """Build an expression suitable to initialize a variable.
+  def _try_guard_against_uninitialized_dependencies(self, initial_value):
+    """Attempt to guard against dependencies on uninitialized variables.
 
-    Replace references to variables in initial_value with references to the
-    variable initial values instead.
+    Replace references to variables in `initial_value` with references to the
+    variable's initialized values. The initialized values are essentially
+    conditional TensorFlow graphs that return a variable's value if it is
+    initialized or its `initial_value` if it hasn't been initialized. This
+    replacement is done on a best effort basis:
+
+    - If the `initial_value` graph contains cycles, we don't do any
+      replacements for that graph.
+    - If the variables that `initial_value` depends on are not present in the
+      `GLOBAL_VARIABLES` or `LOCAL_VARIABLES` we don't replace them.
+
+    In these cases, it is up to the caller to ensure that the `initial_value`
+    graph uses initialized variables or that they guard access to variables
+    using their `initialized_value` method.
 
     Args:
-      initial_value: original expression
+      initial_value: `Tensor`. The initial value.
     Returns:
-      A tensorflow expression suitable to initialize a variable.
+      A `Tensor` suitable to initialize a variable.
+    Raises:
+      TypeError: If `initial_value` is not a `Tensor`.
     """
-    if isinstance(initial_value, Variable):
-      return initial_value.initialized_value()
-    elif isinstance(initial_value, ops.Tensor):
-      new_op = self._build_initializer_expr(initial_value.op)
-      if new_op != initial_value.op:
-        if isinstance(new_op, ops.Tensor):
-          return new_op
-        else:
-          return ops.Tensor(new_op, initial_value.value_index,
-                            initial_value.dtype)
-      else:
-        return initial_value
-    elif isinstance(initial_value, ops.Operation):
-      if initial_value.node_def.op in [
-          "IsVariableInitialized", "VarIsInitializedOp", "ReadVariableOp"
-      ]:
-        return initial_value
-      if initial_value.node_def.op in ["Variable", "VariableV2", "VarHandleOp"]:
-        return self._find_initialized_value_for_variable(initial_value)
-      modified = False
-      new_inputs = []
-      for tensor in initial_value.inputs:
-        new_tensor = self._build_initializer_expr(tensor)
-        new_inputs.append(new_tensor)
-        if new_tensor != tensor:
-          modified = True
+    if not isinstance(initial_value, ops.Tensor):
+      raise TypeError("initial_value needs to be a Tensor: %s" % initial_value)
 
-      if modified:
-        new_name = initial_value.node_def.name + "_" + self.name
-        new_name = new_name.replace(":", "_")
-        new_op = initial_value.node_def.op
-        new_op = new_op.replace("RefSwitch", "Switch")
-        new_value = self.graph.create_op(
-            new_op,
-            new_inputs,
-            # pylint: disable=protected-access
-            initial_value._output_types,
-            # pylint: enable=protected-access
-            name=new_name,
-            attrs=initial_value.node_def.attr)
-        return new_value
-      else:
-        return initial_value
-    else:
+    # Don't modify initial_value if it contains any cyclic dependencies.
+    def has_cycle(op, path):
+      """Detect cycles in the dependencies of `initial_value`."""
+      if op.name in path:
+        return True
+      path.add(op.name)
+      for op_input in op.inputs:
+        if has_cycle(op_input.op, path):
+          return True
+      for op_control_input in op.control_inputs:
+        if has_cycle(op_control_input, path):
+          return True
+      path.remove(op.name)
+      return False
+    if has_cycle(initial_value.op, path=set()):
       return initial_value
 
+    return self._safe_initial_value_from_tensor(initial_value, op_cache={})
+
+  def _safe_initial_value_from_tensor(self, tensor, op_cache):
+    """Replace dependencies on variables with their initialized values.
+
+    Args:
+      tensor: A `Tensor`. The tensor to replace.
+      op_cache: A dict mapping operation names to `Operation`s. Used to memoize
+        the results so as to avoid creating redundant operations.
+    Returns:
+      A `Tensor` compatible with `tensor`. Any inputs that lead to variable
+      values will be replaced with a corresponding graph that uses the
+      variable's initialized values. This is done on a best-effort basis. If no
+      modifications need to be made then `tensor` will be returned unchanged.
+    """
+    op = tensor.op
+    new_op = op_cache.get(op.name)
+    if new_op is None:
+      new_op = self._safe_initial_value_from_op(op, op_cache)
+      op_cache[op.name] = new_op
+    return new_op.outputs[tensor.value_index]
+
+  def _safe_initial_value_from_op(self, op, op_cache):
+    """Replace dependencies on variables with their initialized values.
+
+    Args:
+      op: An `Operation`. The operation to replace.
+      op_cache: A dict mapping operation names to `Operation`s. Used to memoize
+        the results so as to avoid creating redundant operations.
+    Returns:
+      An `Operation` compatible with `op`. Any inputs that lead to variable
+      values will be replaced with a corresponding graph that uses the
+      variable's initialized values. This is done on a best-effort basis. If no
+      modifications need to be made then `op` will be returned unchanged.
+    """
+    op_type = op.node_def.op
+    if op_type in ("IsVariableInitialized", "VarIsInitializedOp",
+                   "ReadVariableOp"):
+      return op
+
+    # Attempt to find the initialized_value of any variable reference / handles.
+    # TODO(b/70206927): Fix handling of ResourceVariables.
+    if op_type in ("Variable", "VariableV2", "VarHandleOp"):
+      initialized_value = self._find_initialized_value_for_variable(op)
+      return op if initialized_value is None else initialized_value.op
+
+    # Recursively build initializer expressions for inputs.
+    modified = False
+    new_op_inputs = []
+    for op_input in op.inputs:
+      new_op_input = self._safe_initial_value_from_tensor(op_input, op_cache)
+      new_op_inputs.append(new_op_input)
+      modified = modified or (new_op_input != op_input)
+
+    # If at least one input was modified, replace the op.
+    if modified:
+      new_op_type = op_type
+      if new_op_type == "RefSwitch":
+        new_op_type = "Switch"
+      new_op_name = op.node_def.name + "_" + self.name
+      new_op_name = new_op_name.replace(":", "_")
+      return self.graph.create_op(
+          new_op_type, new_op_inputs,
+          op._output_types,  # pylint: disable=protected-access
+          name=new_op_name, attrs=op.node_def.attr)
+
+    return op
+
   def _find_initialized_value_for_variable(self, variable_op):
-    """Find the initial value for a variable op.
+    """Find the initialized value for a variable op.
 
     To do so, lookup the variable op in the variables collection.
 
     Args:
-      variable_op: a TensorFlow variable Operation
+      variable_op: A variable `Operation`.
     Returns:
-      The initial value for the variable.
+      A `Tensor` representing the initialized value for the variable or `None`
+      if the initialized value could not be found.
     """
     try:
       var_names = [variable_op.node_def.name, variable_op.node_def.name + ":0"]
-      global_vars = self.graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
-      for var in global_vars:
-        if var.name in var_names:
-          return var.initialized_value()
-      local_vars = self.graph.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
-      for var in local_vars:
-        if var.name == var_names:
-          return var.initialized_value()
+      for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES,
+                              ops.GraphKeys.LOCAL_VARIABLES):
+        for var in self.graph.get_collection(collection_name):
+          if var.name in var_names:
+            return var.initialized_value()
     except AttributeError:
-      # Return the variable itself when an incomplete user defined variable type
-      # was put in the collection.
-      return variable_op
-    return variable_op
+      # Return None when an incomplete user-defined variable type was put in
+      # the collection.
+      return None
+    return None
 
   # NOTE(mrry): This enables the Variable's overloaded "right" binary
   # operators to run when the left operand is an ndarray, because it
diff --git a/tensorflow/python/profiler/internal/print_model_analysis_test.py b/tensorflow/python/profiler/internal/print_model_analysis_test.py
index 797c430..186c028 100644
--- a/tensorflow/python/profiler/internal/print_model_analysis_test.py
+++ b/tensorflow/python/profiler/internal/print_model_analysis_test.py
@@ -18,22 +18,13 @@
 from __future__ import division
 from __future__ import print_function
 
-from google.protobuf import text_format
-
-from tensorflow.core.profiler import tfprof_options_pb2
-from tensorflow.core.profiler import tfprof_output_pb2
-from tensorflow.python.client import session
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import test
 
-# pylint: disable=g-bad-import-order
-# XXX: this depends on pywrap_tensorflow and must come later
-from tensorflow.python import pywrap_tensorflow as print_mdl
 
 # pylint: disable=bad-whitespace
 # pylint: disable=bad-continuation
@@ -69,407 +60,6 @@
     x = nn_ops.conv2d(image, kernel, [1, 2, 2, 1], padding='SAME')
     return x
 
-  def testPrintModelAnalysis(self):
-    opts = tfprof_options_pb2.OptionsProto()
-    opts.max_depth = TEST_OPTIONS['max_depth']
-    opts.min_bytes = TEST_OPTIONS['min_bytes']
-    opts.min_micros = TEST_OPTIONS['min_micros']
-    opts.min_params = TEST_OPTIONS['min_params']
-    opts.min_float_ops = TEST_OPTIONS['min_float_ops']
-    opts.order_by = TEST_OPTIONS['order_by']
-    opts.step = -1
-    for p in TEST_OPTIONS['account_type_regexes']:
-      opts.account_type_regexes.append(p)
-    for p in TEST_OPTIONS['start_name_regexes']:
-      opts.start_name_regexes.append(p)
-    for p in TEST_OPTIONS['trim_name_regexes']:
-      opts.trim_name_regexes.append(p)
-    for p in TEST_OPTIONS['show_name_regexes']:
-      opts.show_name_regexes.append(p)
-    for p in TEST_OPTIONS['hide_name_regexes']:
-      opts.hide_name_regexes.append(p)
-    opts.account_displayed_op_only = TEST_OPTIONS['account_displayed_op_only']
-    for p in TEST_OPTIONS['select']:
-      opts.select.append(p)
-    opts.output = TEST_OPTIONS['output']
-
-    with session.Session() as sess, ops.device('/cpu:0'):
-      _ = self._BuildSmallModel()
-      tfprof_pb = tfprof_output_pb2.GraphNodeProto()
-      tfprof_pb.ParseFromString(
-          print_mdl.PrintModelAnalysis(
-              sess.graph.as_graph_def(add_shapes=True).SerializeToString(),
-              b'',
-              b'',
-              b'scope',
-              opts.SerializeToString()))
-
-      expected_pb = tfprof_output_pb2.GraphNodeProto()
-      text_format.Merge(r"""name: "_TFProfRoot"
-          exec_micros: 0
-          requested_bytes: 0
-          total_exec_micros: 0
-          total_requested_bytes: 0
-          total_parameters: 648
-          children {
-            name: "Conv2D"
-            exec_micros: 0
-            requested_bytes: 0
-            total_exec_micros: 0
-            total_requested_bytes: 0
-            total_parameters: 0
-            float_ops: 0
-            total_float_ops: 0
-            input_shapes {
-              key: 0
-              value {
-                dim {
-                  size: 2
-                }
-                dim {
-                  size: 6
-                }
-                dim {
-                  size: 6
-                }
-                dim {
-                  size: 3
-                }
-              }
-            }
-            input_shapes {
-              key: 1
-              value {
-                dim {
-                  size: 6
-                }
-                dim {
-                  size: 6
-                }
-                dim {
-                  size: 3
-                }
-                dim {
-                  size: 6
-                }
-              }
-            }
-            accelerator_exec_micros: 0
-            cpu_exec_micros: 0
-            total_accelerator_exec_micros: 0
-            total_cpu_exec_micros: 0
-            run_count: 0
-            total_run_count: 0
-            total_definition_count: 1
-          }
-          children {
-            name: "DW"
-            exec_micros: 0
-            requested_bytes: 0
-            parameters: 648
-            total_exec_micros: 0
-            total_requested_bytes: 0
-            total_parameters: 648
-            children {
-              name: "DW/Assign"
-              exec_micros: 0
-              requested_bytes: 0
-              total_exec_micros: 0
-              total_requested_bytes: 0
-              total_parameters: 0
-              float_ops: 0
-              total_float_ops: 0
-              input_shapes {
-                key: 0
-                value {
-                  dim {
-                    size: 6
-                  }
-                  dim {
-                    size: 6
-                  }
-                  dim {
-                    size: 3
-                  }
-                  dim {
-                    size: 6
-                  }
-                }
-              }
-              input_shapes {
-                key: 1
-                value {
-                  dim {
-                    size: 6
-                  }
-                  dim {
-                    size: 6
-                  }
-                  dim {
-                    size: 3
-                  }
-                  dim {
-                    size: 6
-                  }
-                }
-              }
-              accelerator_exec_micros: 0
-              cpu_exec_micros: 0
-              total_accelerator_exec_micros: 0
-              total_cpu_exec_micros: 0
-              run_count: 0
-              total_run_count: 0
-              total_definition_count: 1
-            }
-            children {
-              name: "DW/Initializer"
-              exec_micros: 0
-              requested_bytes: 0
-              total_exec_micros: 0
-              total_requested_bytes: 0
-              total_parameters: 0
-              children {
-                name: "DW/Initializer/random_normal"
-                exec_micros: 0
-                requested_bytes: 0
-                total_exec_micros: 0
-                total_requested_bytes: 0
-                total_parameters: 0
-                children {
-                  name: "DW/Initializer/random_normal/RandomStandardNormal"
-                  exec_micros: 0
-                  requested_bytes: 0
-                  total_exec_micros: 0
-                  total_requested_bytes: 0
-                  total_parameters: 0
-                  float_ops: 0
-                  total_float_ops: 0
-                  input_shapes {
-                    key: 0
-                    value {
-                      dim {
-                        size: 4
-                      }
-                    }
-                  }
-                  accelerator_exec_micros: 0
-                  cpu_exec_micros: 0
-                  total_accelerator_exec_micros: 0
-                  total_cpu_exec_micros: 0
-                  run_count: 0
-                  total_run_count: 0
-                  total_definition_count: 1
-                }
-                children {
-                  name: "DW/Initializer/random_normal/mean"
-                  exec_micros: 0
-                  requested_bytes: 0
-                  total_exec_micros: 0
-                  total_requested_bytes: 0
-                  total_parameters: 0
-                  float_ops: 0
-                  total_float_ops: 0
-                  accelerator_exec_micros: 0
-                  cpu_exec_micros: 0
-                  total_accelerator_exec_micros: 0
-                  total_cpu_exec_micros: 0
-                  run_count: 0
-                  total_run_count: 0
-                  total_definition_count: 1
-                }
-                children {
-                  name: "DW/Initializer/random_normal/mul"
-                  exec_micros: 0
-                  requested_bytes: 0
-                  total_exec_micros: 0
-                  total_requested_bytes: 0
-                  total_parameters: 0
-                  float_ops: 0
-                  total_float_ops: 0
-                  input_shapes {
-                    key: 0
-                    value {
-                      dim {
-                        size: 6
-                      }
-                      dim {
-                        size: 6
-                      }
-                      dim {
-                        size: 3
-                      }
-                      dim {
-                        size: 6
-                      }
-                    }
-                  }
-                  input_shapes {
-                    key: 1
-                    value {
-                      dim {
-                        size: 1
-                      }
-                    }
-                  }
-                  accelerator_exec_micros: 0
-                  cpu_exec_micros: 0
-                  total_accelerator_exec_micros: 0
-                  total_cpu_exec_micros: 0
-                  run_count: 0
-                  total_run_count: 0
-                  total_definition_count: 1
-                }
-                children {
-                  name: "DW/Initializer/random_normal/shape"
-                  exec_micros: 0
-                  requested_bytes: 0
-                  total_exec_micros: 0
-                  total_requested_bytes: 0
-                  total_parameters: 0
-                  float_ops: 0
-                  total_float_ops: 0
-                  accelerator_exec_micros: 0
-                  cpu_exec_micros: 0
-                  total_accelerator_exec_micros: 0
-                  total_cpu_exec_micros: 0
-                  run_count: 0
-                  total_run_count: 0
-                  total_definition_count: 1
-                }
-                children {
-                  name: "DW/Initializer/random_normal/stddev"
-                  exec_micros: 0
-                  requested_bytes: 0
-                  total_exec_micros: 0
-                  total_requested_bytes: 0
-                  total_parameters: 0
-                  float_ops: 0
-                  total_float_ops: 0
-                  accelerator_exec_micros: 0
-                  cpu_exec_micros: 0
-                  total_accelerator_exec_micros: 0
-                  total_cpu_exec_micros: 0
-                  run_count: 0
-                  total_run_count: 0
-                  total_definition_count: 1
-                }
-                float_ops: 0
-                total_float_ops: 0
-                input_shapes {
-                  key: 0
-                  value {
-                    dim {
-                      size: 6
-                    }
-                    dim {
-                      size: 6
-                    }
-                    dim {
-                      size: 3
-                    }
-                    dim {
-                      size: 6
-                    }
-                  }
-                }
-                input_shapes {
-                  key: 1
-                  value {
-                    dim {
-                      size: 1
-                    }
-                  }
-                }
-                accelerator_exec_micros: 0
-                cpu_exec_micros: 0
-                total_accelerator_exec_micros: 0
-                total_cpu_exec_micros: 0
-                run_count: 0
-                total_run_count: 0
-                total_definition_count: 6
-              }
-              float_ops: 0
-              total_float_ops: 0
-              accelerator_exec_micros: 0
-              cpu_exec_micros: 0
-              total_accelerator_exec_micros: 0
-              total_cpu_exec_micros: 0
-              run_count: 0
-              total_run_count: 0
-              total_definition_count: 7
-            }
-            children {
-              name: "DW/read"
-              exec_micros: 0
-              requested_bytes: 0
-              total_exec_micros: 0
-              total_requested_bytes: 0
-              total_parameters: 0
-              float_ops: 0
-              total_float_ops: 0
-              input_shapes {
-                key: 0
-                value {
-                  dim {
-                    size: 6
-                  }
-                  dim {
-                    size: 6
-                  }
-                  dim {
-                    size: 3
-                  }
-                  dim {
-                    size: 6
-                  }
-                }
-              }
-              accelerator_exec_micros: 0
-              cpu_exec_micros: 0
-              total_accelerator_exec_micros: 0
-              total_cpu_exec_micros: 0
-              run_count: 0
-              total_run_count: 0
-              total_definition_count: 1
-            }
-            float_ops: 0
-            total_float_ops: 0
-            accelerator_exec_micros: 0
-            cpu_exec_micros: 0
-            total_accelerator_exec_micros: 0
-            total_cpu_exec_micros: 0
-            run_count: 0
-            total_run_count: 0
-            total_definition_count: 10
-          }
-          children {
-            name: "zeros"
-            exec_micros: 0
-            requested_bytes: 0
-            total_exec_micros: 0
-            total_requested_bytes: 0
-            total_parameters: 0
-            float_ops: 0
-            total_float_ops: 0
-            accelerator_exec_micros: 0
-            cpu_exec_micros: 0
-            total_accelerator_exec_micros: 0
-            total_cpu_exec_micros: 0
-            run_count: 0
-            total_run_count: 0
-            total_definition_count: 1
-          }
-          float_ops: 0
-          total_float_ops: 0
-          accelerator_exec_micros: 0
-          cpu_exec_micros: 0
-          total_accelerator_exec_micros: 0
-          total_cpu_exec_micros: 0
-          run_count: 0
-          total_run_count: 0
-          total_definition_count: 13""", expected_pb)
-      self.assertEqual(expected_pb, tfprof_pb)
-
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index a379bd5..14ad9e5 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -164,13 +164,6 @@
       model_analyzer.profile(
           sess.graph, run_meta, options=opts)
 
-      with gfile.Open(outfile, 'r') as f:
-        # pylint: disable=line-too-long
-        self.assertEqual(
-            'node name | # parameters | # float_ops | assigned devices | op types | op count (run|defined) | input shapes\n_TFProfRoot (--/451 params, --/11.34k flops, _kTFScopeParent, --/8|--/36, )\n  Conv2D (0/0 params, 5.83k/5.83k flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|Conv2D, 1/1|1/1, 0:2x6x6x3|1:3x3x3x6)\n  Conv2D_1 (0/0 params, 4.61k/4.61k flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|Conv2D, 1/1|1/1, 0:2x3x3x6|1:2x2x6x12)\n  DW (3x3x3x6, 162/162 params, 0/324 flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n    DW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:3x3x3x6|1:3x3x3x6)\n    DW/Initializer (0/0 params, 0/324 flops, _kTFScopeParent, 0/0|1/7, )\n      DW/Initializer/random_normal (0/0 params, 162/324 flops, Add, 0/0|1/6, 0:3x3x3x6|1:1)\n        DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:4)\n        DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n        DW/Initializer/random_normal/mul (0/0 params, 162/162 flops, Mul, 0/0|1/1, 0:3x3x3x6|1:1)\n        DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n        DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n    DW/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|Identity, 1/1|1/1, 0:3x3x3x6)\n  DW2 (2x2x6x12, 288/288 params, 0/576 flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n    DW2/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:2x2x6x12|1:2x2x6x12)\n    DW2/Initializer (0/0 params, 0/576 flops, _kTFScopeParent, 0/0|1/7, )\n      DW2/Initializer/random_normal (0/0 params, 288/576 
flops, Add, 0/0|1/6, 0:2x2x6x12|1:1)\n        DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:4)\n        DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n        DW2/Initializer/random_normal/mul (0/0 params, 288/288 flops, Mul, 0/0|1/1, 0:2x2x6x12|1:1)\n        DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n        DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n    DW2/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|Identity, 1/1|1/1, 0:2x2x6x12)\n  ScalarW (1, 1/1 params, 0/2 flops, VariableV2|_trainable_variables, 0/0|1/10, )\n    ScalarW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:1|1:1)\n    ScalarW/Initializer (0/0 params, 0/2 flops, _kTFScopeParent, 0/0|1/7, )\n      ScalarW/Initializer/random_normal (0/0 params, 1/2 flops, Add, 0/0|1/6, 0:1|1:1)\n        ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:0)\n        ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n        ScalarW/Initializer/random_normal/mul (0/0 params, 1/1 flops, Mul, 0/0|1/1, 0:1|1:1)\n        ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n        ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n    ScalarW/read (0/0 params, 0/0 flops, Identity, 0/0|1/1, 0:1)\n  _retval_Conv2D_1_0_0 (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|_retval_Conv2D_1_0_0, 1/1|1/1, )\n  init (0/0 params, 0/0 flops, NoOp, 0/0|1/1, 0:1|1:3x3x3x6|2:2x2x6x12)\n  zeros (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/device:cpu:0, /job:localhost/replica:0/task:0/device:cpu:0|Const, 1/1|1/1, )\n',
-            f.read())
-        # pylint: enable=line-too-long
-
   def testSimpleCodeView(self):
     ops.reset_default_graph()
     outfile = os.path.join(test.get_temp_dir(), 'dump')
@@ -376,7 +369,6 @@
         self.assertLessEqual(len(tfprof_node.graph_nodes), last_occurrence)
         last_occurrence = len(tfprof_node.graph_nodes)
 
-      self.assertEqual(total_children, 15)
       self.assertGreater(input_shapes, 0)
 
   def testAdvisor(self):
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 82750e9..d97823c 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -20,6 +20,7 @@
 %rename("%s") TFE_ContextListDevices;
 %rename("%s") TFE_ContextAddFunction;
 %rename("%s") TFE_ContextAddFunctionDef;
+%rename("%s") TFE_ContextClearCaches;
 %rename("%s") TFE_OpNameGetAttrType;
 %rename("%s") TFE_Py_InitEagerTensor;
 %rename("%s") TFE_Py_RegisterExceptionClass;
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py
index 5879fd3..6670d93 100644
--- a/tensorflow/python/training/session_manager_test.py
+++ b/tensorflow/python/training/session_manager_test.py
@@ -26,6 +26,7 @@
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
@@ -504,6 +505,7 @@
           trainable=False,
           collections=[ops.GraphKeys.LOCAL_VARIABLES],
           name="x")
+      # TODO(b/70206927): Use ResourceVariables once they are handled properly.
       v_res = variables.Variable(1, name="v_res")
       w_res = variables.Variable(
           v_res,
@@ -556,6 +558,24 @@
       self.assertEquals(1, sess.run(w_res))
       self.assertEquals(3, sess.run(x_res))
 
+  def testPrepareSessionWithCyclicInitializer(self):
+    # Regression test. Previously Variable._build_initializer_expr would
+    # recurse infinitely when the variable's initial_value involved
+    # cyclic dependencies.
+    with ops.Graph().as_default():
+      i = control_flow_ops.while_loop(lambda i: i < 1, lambda i: i + 1, [0])
+      v = variables.Variable(array_ops.identity(i), name="v")
+      with self.test_session():
+        self.assertEqual(False, variables.is_variable_initialized(v).eval())
+      sm = session_manager.SessionManager(
+          ready_op=variables.report_uninitialized_variables())
+      sess = sm.prepare_session("", init_op=v.initializer)
+      self.assertEqual(1, sess.run(v))
+      self.assertEqual(
+          True,
+          variables.is_variable_initialized(
+              sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
+
   def testPrepareSessionDidNotInitLocalVariable(self):
     with ops.Graph().as_default():
       v = variables.Variable(1, name="v")
diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
index 80f2b59..6271fce 100755
--- a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
+++ b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh
@@ -73,13 +73,16 @@
 
 # Zip up the .dll, LICENSE and include files for the C library.
 mkdir -p ${DIR}/include/tensorflow/c
+mkdir -p ${DIR}/include/tensorflow/c/eager
 mkdir -p ${DIR}/lib
 cp bazel-bin/tensorflow/libtensorflow.so ${DIR}/lib/tensorflow.dll
 cp tensorflow/c/c_api.h ${DIR}/include/tensorflow/c
+cp tensorflow/c/eager/c_api.h ${DIR}/include/tensorflow/c/eager
 cp bazel-genfiles/tensorflow/tools/lib_package/include/tensorflow/c/LICENSE ${DIR}/include/tensorflow/c
 cd ${DIR}
 zip -j libtensorflow-cpu-windows-$(uname -m).zip \
   lib/tensorflow.dll \
+  include/tensorflow/c/eager/c_api.h \
   include/tensorflow/c/c_api.h \
   include/tensorflow/c/LICENSE
 rm -rf lib include
diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD
index 845bad5..dbc8159 100644
--- a/tensorflow/tools/lib_package/BUILD
+++ b/tensorflow/tools/lib_package/BUILD
@@ -55,7 +55,10 @@
 
 pkg_tar(
     name = "cheaders",
-    files = ["//tensorflow/c:headers"],
+    files = [
+        "//tensorflow/c:headers",
+        "//tensorflow/c/eager:headers",
+    ],
     package_dir = "include/tensorflow/c",
     # Mark as "manual" till
     # https://github.com/bazelbuild/bazel/issues/2352
diff --git a/tools/bazel.rc b/tools/bazel.rc
index 04c24d7..44bfe9f 100644
--- a/tools/bazel.rc
+++ b/tools/bazel.rc
@@ -1,3 +1,28 @@
+# Android configs
+build:android --crosstool_top=//external:android/crosstool
+build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
+build:android_arm --config=android
+build:android_arm --cpu=armeabi-v7a
+build:android_arm64 --config=android
+build:android_arm64 --cpu=arm64-v8a
+
+# Config to use a mostly-static build and disable modular op registration
+# support (this will revert to loading TensorFlow with RTLD_GLOBAL in Python).
+# By default, TensorFlow will build with a dependence on
+# //tensorflow:libtensorflow_framework.so.
+build:monolithic --define framework_shared_object=false
+
+# For projects which use TensorFlow as part of a Bazel build process, putting
+# nothing in a bazelrc will default to a monolithic build. The following line
+# opts in to modular op registration support by default.
+build --define framework_shared_object=true
+
+# Please note that MKL on macOS or Windows is still not supported.
+# If you would like to use a local MKL instead of downloading, please set the
+# environment variable "TF_MKL_ROOT" every time before building.
+build:mkl --define=using_mkl=true
+build:mkl -c opt
+
 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
 build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true