Merge pull request #26631 from trevor-m:tmorris_tftrt_dont_settype_trt_5_1_3

PiperOrigin-RevId: 238142352
diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD
index c4d3676..a7fb772 100644
--- a/tensorflow/compiler/tf2tensorrt/BUILD
+++ b/tensorflow/compiler/tf2tensorrt/BUILD
@@ -11,6 +11,7 @@
 
 load(
     "//tensorflow:tensorflow.bzl",
+    "tf_cc_shared_object",
     "tf_cc_test",
     "tf_copts",
     "tf_cuda_library",
@@ -46,19 +47,6 @@
     ]),
 )
 
-tf_custom_op_library(
-    name = "python/ops/_trt_ops.so",
-    srcs = [
-        "ops/get_serialized_resource_op.cc",
-        "ops/trt_engine_op.cc",
-    ],
-    deps = [
-        "//tensorflow/core:lib_proto_parsing",
-    ] + if_tensorrt([
-        "@local_config_tensorrt//:tensorrt",
-    ]),
-)
-
 cc_library(
     name = "trt_op_kernels",
     srcs = [
@@ -86,6 +74,22 @@
     alwayslink = 1,
 )
 
+tf_cc_shared_object(
+    name = "python/ops/libtftrt.so",
+    srcs = [
+        "ops/get_serialized_resource_op.cc",
+        "ops/trt_engine_op.cc",
+    ],
+    copts = tf_copts(is_external = True),
+    linkopts = ["-lm"],
+    deps = [
+        ":trt_op_kernels",
+        "//tensorflow/core:lib_proto_parsing",
+    ] + if_tensorrt([
+        "@local_config_tensorrt//:tensorrt",
+    ]) + tf_custom_op_library_additional_deps(),
+)
+
 tf_cuda_cc_test(
     name = "get_serialized_resource_op_test",
     size = "small",
@@ -149,7 +153,7 @@
     name = "trt_ops_loader",
     srcs = ["python/ops/trt_ops.py"],
     dso = [
-        "python/ops/_trt_ops.so",
+        "python/ops/libtftrt.so",
     ] + if_tensorrt([
         "@local_config_tensorrt//:tensorrt",
     ]),
diff --git a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py
index 25fb3a1..62ac5a5 100644
--- a/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py
+++ b/tensorflow/compiler/tf2tensorrt/python/ops/trt_ops.py
@@ -23,19 +23,19 @@
 import platform
 from tensorflow.python.framework import errors
 
-_trt_ops_so = None
+_tf_trt_so = None
 _module_lock = threading.Lock()
 
 
 def load_trt_ops():
   """Load TF-TRT op libraries so if it hasn't been loaded already."""
-  global _trt_ops_so
+  global _tf_trt_so
 
   if platform.system() == "Windows":
     raise RuntimeError("Windows platforms are not supported")
 
   with _module_lock:
-    if _trt_ops_so:
+    if _tf_trt_so:
       return
 
     try:
@@ -56,8 +56,8 @@
       from tensorflow.python.platform import resource_loader
       # pylint: enable=g-import-not-at-top
 
-      _trt_ops_so = load_library.load_op_library(
-          resource_loader.get_path_to_datafile("_trt_ops.so"))
+      _tf_trt_so = load_library.load_op_library(
+          resource_loader.get_path_to_datafile("libtftrt.so"))
     except errors.NotFoundError as e:
       no_trt_message = (
           "**** Failed to initialize TensorRT. This is either because the "
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 600a935..488d0c7 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -1442,8 +1442,19 @@
   return Status::OK();
 }
 
+using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
+
 }  // namespace
 
+string InlineFunctionBodyOptions::DebugString() const {
+  return absl::StrCat("ignore_noinline=", ignore_noinline ? "true" : "false",
+                      ", override_device=", override_device ? "true" : "false",
+                      ", output_control_src=",
+                      output_control_src == OutputControlSrc::kDataOutputs
+                          ? "DataOutputs"
+                          : "ControlOutputs");
+}
+
 Status ValidateInlining(const Node* node, const FunctionBody* fbody,
                         const InlineFunctionBodyOptions& options) {
   // TODO(ezhulenev): Currently common_runtime function inlining can't guarantee
@@ -1544,8 +1555,8 @@
 // 2) Create "output_control_node" NoOp. All nodes that have incoming control
 //    edge *from* the function call node, will be forwarded to this node.
 //
-//    We have two options for choosing which nodes will a control edge *to* the
-//    "output control node":
+//    We have two options for choosing which nodes will have a control edge *to*
+//    the "output control node":
 //       a) control returns            (`control_ret` field in FunctionDef)
 //       b) data returns               (`ret` field in FunctionDef)
 //
@@ -1574,7 +1585,8 @@
 Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
                           Node* caller, const FunctionBody* fbody,
                           const InlineFunctionBodyOptions& options) {
-  VLOG(3) << "Inline function call: " << SummarizeNode(*caller);
+  VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " ["
+          << options.DebugString() << "]";
   VLOG(4) << "Inlined function definition: " << DebugString(fbody->fdef);
 
   Status validation = ValidateInlining(caller, fbody, options);
@@ -1585,8 +1597,8 @@
   }
 
   // ------------------------------------------------------------------------ //
-  // We insert NoOps before/after inlined function body nodes, to enforce
-  // side-effects execution order.
+  // Helper functions to create `NoOp` and `Identity` nodes for auxiliary
+  // control nodes and inlined function inputs and outputs.
 
   // Add a NoOp node for function control inputs/outputs.
   const auto no_op = [&](StringPiece name) {
@@ -1710,16 +1722,17 @@
   // ------------------------------------------------------------------------ //
   // Connect output edges.
   //
-  // For i-th return node in fbody->graph, we add in "g" an identity
-  // node (outputs[i-th]). We then reconnect every incoming edge into
-  // the i-th return node to the added identity node.
+  // For i-th return node in fbody->graph, we add in "g" an identity node
+  // (outputs[i-th]). We then reconnect every incoming edge into the i-th return
+  // node to the added identity node.
   //
-  // For every data edge coming out of "callee"s i-th output, we
-  // reconnect it to the i-th identity added above.
+  // For every data edge coming out of "callee"s i-th output, we reconnect it to
+  // the i-th identity added above.
   //
-  // If "callee" is control-depended upon by any other nodes, we add a
-  // NoOp node "output_control_node". "output_control_node" depends on
-  // all identity nodes added above. And nodes previously depend on
+  // If "callee" is control-depended upon by any other nodes, we add a NoOp node
+  // "output_control_node". "output_control_node" depends on all identity nodes
+  // added above or on all control return nodes (controlled by
+  // `options.output_control_src` value). And nodes previously depend on
   // "callee" is changed to depend on "output_control_node".
   std::vector<Node*> outputs(caller->num_outputs());
   for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
@@ -1746,8 +1759,16 @@
     if (e->IsControlEdge()) {
       if (output_control_node == nullptr) {
         output_control_node = no_op("output_control_node");
-        for (Node* n : outputs) {
-          g->AddControlEdge(n, output_control_node);
+        if (options.output_control_src ==
+            InlineFunctionBodyOptions::OutputControlSource::kDataOutputs) {
+          for (Node* n : outputs) {
+            g->AddControlEdge(n, output_control_node);
+          }
+        } else {
+          for (Node* fbody_node : fbody->control_ret_nodes) {
+            Node* n = node_map[fbody_node->id()];
+            g->AddControlEdge(n, output_control_node);
+          }
         }
       }
       g->AddControlEdge(output_control_node, e->dst());
@@ -1768,7 +1789,7 @@
 }
 
 bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
-                           const InlineFunctionBodyOptions& options) {
+                           const ExpandInlineFunctionsOptions& options) {
   std::vector<std::pair<Node*, const FunctionBody*>> candidates;
 
   const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition();
@@ -1797,8 +1818,10 @@
 
   bool inlined_any = false;
   for (const auto& p : candidates) {
-    Status inlined =
-        InlineFunctionBody(*fld, graph, p.first, p.second, options);
+    Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second,
+                                        p.first->IsPartitionedCall()
+                                            ? options.multi_device_options
+                                            : options.native_options);
     if (inlined.ok()) {
       inlined_any = true;
     } else {
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
index b6db1cb..86b4d21 100644
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -160,11 +160,26 @@
 FunctionBody* SymbolicGradient(const FunctionBody& f);
 
 struct InlineFunctionBodyOptions {
+  // All nodes that have incoming control edge *from* the function call node,
+  // will be forwarded to the "output control node". There are two options for
+  // choosing which nodes will have a control edge *to* the "output control
+  // node":
+  //   a) control returns            (`control_ret` field in FunctionDef)
+  //   b) data returns               (`ret` field in FunctionDef)
+  enum class OutputControlSource { kDataOutputs, kControlOutputs };
+
   // Ignore '_noinline' function attribute.
   bool ignore_noinline = false;
   // If 'true' function inlining will override explicitly specified devices
   // inside function body with the caller node device.
   bool override_device = false;
+  // For compatibility with Tensorflow v1 by default we will use data outputs.
+  // Control returns were added to Tensorflow v2 with automatic control
+  // dependencies tracking in Eager mode.
+  OutputControlSource output_control_src = OutputControlSource::kDataOutputs;
+
+  // Returns a human-readable debug string describing these options.
+  string DebugString() const;
 };
 
 // Returns 'Status::OK()' iff the function '*fbody' can be inlined at 'node'
@@ -192,6 +207,48 @@
                           Node* caller, const FunctionBody* fbody,
                           const InlineFunctionBodyOptions& options);
 
+// There are three types of function calls that could be invoked during
+// *Tensorflow graph execution*:
+//
+// 1) Native function call (node.type_string() is the function name). These
+//    functions are always executed on a single-device, which is the device of
+//    the function call node.
+//
+// 2) Multi-device function calls (PartitionedCall or StatefulPartitionedCall
+//    ops) can execute on multiple devices and accept DT_RESOURCE inputs that
+//    belong to different devices. This type of function was added in
+//    Tensorflow 2.0 Eager mode, and it has control outputs to represent
+//    side-effects that must always execute (see `control_ret` in FunctionDef).
+//
+// 3) SymbolicGradient has been deprecated for a while, but we still keep it and
+//    use `native` options for inlining for compatibility.
+//
+// We need to have distinct inlining rules for compatibility with Tensorflow v1.
+//
+// There are a few other places in Tensorflow that could execute functions:
+//
+// 1) common_runtime/eager/kernel_and_device.{h,cc} - executes "top level"
+//    functions directly via function library runtime, without going through
+//    the graph.
+// 2) tf.data pipelines - also execute functions directly via function library
+//    runtime with custom executors.
+struct ExpandInlineFunctionsOptions {
+  ExpandInlineFunctionsOptions() : native_options(), multi_device_options() {
+    using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
+    multi_device_options.output_control_src = OutputControlSrc::kControlOutputs;
+  }
+
+  InlineFunctionBodyOptions native_options;
+  InlineFunctionBodyOptions multi_device_options;
+};
+
+// WARNING(ezhulenev): PLEASE DO NOT USE THIS FUNCTION. This is a temporary
+// workaround that will be enabled only during the function inlining unification
+// (b/126811947). Contact ezhulenev@ if you think you need it.
+// TODO(ezhulenev): Delete this function.
+bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
+                           const ExpandInlineFunctionsOptions& options);
+
 // For each node in "graph", if "lib" indicates that the node is a
 // function call, inline the function body. Returns true if at least
 // one node is inlined.
@@ -203,13 +260,11 @@
 // Function calls that can't be safely inlined into the graph (ValidateInlining
 // returns error), are ignored.
 //
-// If `override_device` is true then the inlined operations are placed on the
-// device the call node is placed on.
-bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
-                           const InlineFunctionBodyOptions& options);
-
+// TODO(ezhulenev): We do not need FunctionLibraryRuntime for this. We need
+// just the FunctionLibraryDefinition and FunctionDefToBodyHelper to implement
+// this (see lower_function_call.cc).
 inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
-  return ExpandInlineFunctions(lib, graph, InlineFunctionBodyOptions());
+  return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
 }
 
 // Extracts function name and attributes from `call_def` and invokes
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 72b2b14..15910af 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -801,7 +801,7 @@
 
 // Verifies that control dependencies on the caller are added as control
 // dependencies on any function calls created by inlining.
-TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithControlDeps) {
+TEST_F(FunctionLibraryRuntimeTest, ExpandInlineFunctionsWithInputControlEdges) {
   Init({test::function::XTimesTwo(), test::function::XTimesFour()});
 
   std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
@@ -885,6 +885,99 @@
   }
 }
 
+TEST_F(FunctionLibraryRuntimeTest,
+       ExpandInlineFunctionsWithOutputControlEdges) {
+  using test::function::NDef;
+  using FDH = FunctionDefHelper;
+  using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
+
+  // `add` node is not required to compute regular output `o`, but it must
+  // execute because it is in `control_ret`.
+  const FunctionDef func =
+      FDH::Create("FunctionWithControlOutputs", {"i: float"}, {"o: float"}, {},
+                  {
+                      {{"add"}, "Add", {"i", "i"}, {{"T", DT_FLOAT}}},
+                      {{"ret"}, "Mul", {"i", "i"}, {{"T", DT_FLOAT}}},
+                  },
+                  /*ret_def=*/{{"o", "ret:z:0"}},
+                  /*control_ret_def=*/{{"must_execute", "add"}});
+
+  Init({func});
+
+  // Construct a graph for the function call:
+  //
+  //   a = Arg[dtype=DT_FLOAT]
+  //   b = FunctionWithControlOutputs(a)
+  //   c = NoOp(^b)
+  //   ret = RetVal(b, ^c)
+  const auto init_graph = [this](std::unique_ptr<Graph>* g) -> void {
+    g->reset(new Graph(OpRegistry::Global()));
+
+    Scope s = Scope::NewRootScope();
+    TF_ASSERT_OK(s.graph()->AddFunctionLibrary(fdef_lib_));
+    auto a = ops::_Arg(s.WithOpName("a"), DT_FLOAT, 0);
+    auto b = test::function::Call(&s, "b", "FunctionWithControlOutputs", {a});
+    auto c = ops::NoOp(s.WithOpName("c"));
+    auto ret = ops::_Retval(s.WithOpName("ret"), b, 0);
+    s.graph()->AddControlEdge(b.node(), c.operation.node());
+    s.graph()->AddControlEdge(c.operation.node(), ret.operation.node());
+    TF_ASSERT_OK(s.ToGraph(g->get()));
+  };
+
+  std::unique_ptr<Graph> g;
+  ExpandInlineFunctionsOptions opts;
+
+  const string input_node = "Func/b/input/_0";
+  const string output_node = "Func/b/output/_1";
+  const string output_control_node = "Func/b/output_control_node/_2";
+
+  // Use data outputs as output control source.
+  opts.native_options.output_control_src = OutputControlSrc::kDataOutputs;
+
+  init_graph(&g);
+  ExpandInlineFunctions(flr0_, g.get(), opts);
+  {
+    GraphDef expected = test::function::GDef(
+        {NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}),
+         NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}),
+         NDef("b/add", "Add", {input_node, input_node}, {{"T", DT_FLOAT}}),
+         NDef("b/ret", "Mul", {input_node, input_node}, {{"T", DT_FLOAT}}),
+         NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}}),
+         NDef(output_control_node, "NoOp", {"^Func/b/output/_1"}, {}),
+         NDef("c", "NoOp", {"^" + output_control_node}, {}),
+         NDef("ret", "_Retval", {output_node, "^c"},
+              {{"T", DT_FLOAT}, {"index", 0}})},
+        {func});
+
+    GraphDef actual;
+    g->ToGraphDef(&actual);
+    TF_EXPECT_GRAPH_EQ(expected, actual);
+  }
+
+  // Use control outputs as output control source.
+  opts.native_options.output_control_src = OutputControlSrc::kControlOutputs;
+
+  init_graph(&g);
+  ExpandInlineFunctions(flr0_, g.get(), opts);
+  {
+    GraphDef expected = test::function::GDef(
+        {NDef("a", "_Arg", {}, {{"T", DT_FLOAT}, {"index", 0}}),
+         NDef(input_node, "Identity", {"a"}, {{"T", DT_FLOAT}}),
+         NDef("b/add", "Add", {input_node, input_node}, {{"T", DT_FLOAT}}),
+         NDef("b/ret", "Mul", {input_node, input_node}, {{"T", DT_FLOAT}}),
+         NDef(output_node, "Identity", {"b/ret"}, {{"T", DT_FLOAT}}),
+         NDef(output_control_node, "NoOp", {"^b/add"}, {}),
+         NDef("c", "NoOp", {"^" + output_control_node}, {}),
+         NDef("ret", "_Retval", {output_node, "^c"},
+              {{"T", DT_FLOAT}, {"index", 0}})},
+        {func});
+
+    GraphDef actual;
+    g->ToGraphDef(&actual);
+    TF_EXPECT_GRAPH_EQ(expected, actual);
+  }
+}
+
 TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
   auto T = DT_INT32;
   FunctionDef stateful_func = FDH::Define(
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index f5352ec..465cddf 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -87,10 +87,10 @@
       changed = true;
     }
     if (opts_.do_function_inlining()) {
-      InlineFunctionBodyOptions inline_opts;
-      inline_opts.override_device = true;
+      ExpandInlineFunctionsOptions expand_inline_opts;
+      expand_inline_opts.native_options.override_device = true;
 
-      bool was_mutated = ExpandInlineFunctions(runtime, g, inline_opts);
+      bool was_mutated = ExpandInlineFunctions(runtime, g, expand_inline_opts);
       if (was_mutated) {
         DumpGraph("ExpandInlineFunctions", g);
         changed = true;
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index bf00ed9..389ab39 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -680,7 +680,7 @@
 Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
                            GetFunctionSignature get_function,
                            InstantiationResult* result) {
-  VLOG(3) << "Instantiation Function: " << Print(fdef);
+  VLOG(4) << "Instantiation Function: " << Print(fdef);
 
   const OpDef& sig = fdef.signature();
   TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index a3c0230..52c8f6f 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -333,8 +333,7 @@
     name_to_node[node->name()] = node;
   }
 
-  // Traverse the graph to check if the graph is annotated with Switch outputs.
-  // Also record _Send nodes.
+  // Traverses the graph to record _Send nodes.
   // TODO(dyoon): Instead of identifying _Send node here manually, add _Send
   // to _Recv as control dependency when creating GrapplerItem.
   std::unordered_map<string, const NodeDef*> name_to_send;
@@ -343,11 +342,6 @@
       const auto& attr = node.attr();
       name_to_send[attr.at("tensor_name").s()] = &node;
     }
-
-    if (IsSwitch(node)) {
-      const auto& attr = node.attr();
-      if (attr.count(kOutputSlots) > 0) switch_outputs_annotated_ = true;
-    }
   }
 
   // To reuse _Recv ops.
@@ -709,66 +703,29 @@
   return it->second;
 }
 
-// Check Switch outputs in updated MetaGraphDef, add corresponding nodes to
-// ready queue.
-// Fallback to add all outputs if fail to find the actual output.
-bool VirtualScheduler::AddSwitchOutputsToReadyQueue(
-    const NodeDef* node, int curr_iter, const Costs::Duration& curr_time) {
-  if (node->attr().count(kOutputSlots) == 0) return false;
-
-  auto& node_state = node_map_[node];
-  const auto& slot_vector = node->attr().at(kOutputSlots);
-  if (slot_vector.list().i_size() <= curr_iter) {
-    // Sometimes we encounter infinite loop. Fall back to add all outputs.
-    return false;
-  }
-
-  int slot = slot_vector.list().i(curr_iter);
-  for (const auto& port_num_output_pair : node_state.outputs) {
-    if (port_num_output_pair.first != slot) continue;
-
-    for (auto* output_node : port_num_output_pair.second) {
-      auto& output_state = node_map_[output_node];
-      output_state.num_inputs_ready++;
-      // Execute a node as soon as all its inputs are ready. Merge nodes
-      // are special since they run as soon as one of their inputs becomes
-      // available.
-      if (output_state.num_inputs_ready == output_state.inputs.size() ||
-          IsMerge(*output_node)) {
-        // This output node is now ready.
-        output_state.time_ready = curr_time;
-        ready_nodes_->AddNode(output_node);
-        VLOG(3) << "Node " << node->name() << " iter " << curr_iter << "/"
-                << slot_vector.list().i_size() << " Add Switch output " << slot
-                << ": " << output_node->name();
-      }
-    }
-    return true;
-  }
-
-  return false;
-}
-
 void VirtualScheduler::AddOutputNodesToReadyQueue(
     const NodeDef* node, const Costs::Duration& curr_time) {
-  auto& node_state = node_map_[node];
-  int curr_iter = node_state.num_executed_times;
-  ++node_state.num_executed_times;
-
-  if (switch_outputs_annotated_) {
-    // If the graph is annotated with StepStats, reset num_inputs_ready so we
-    // can schedule the node multiple times.
-    node_state.num_inputs_ready = 0;
-
-    // For Switch node, get output branch from updated MetaGraphDef.
-    if (IsSwitch(*node) &&
-        AddSwitchOutputsToReadyQueue(node, curr_iter, curr_time))
-      return;
+  // Checks whether the Switch's output slots change over iterations.
+  int slot = -1;
+  if (IsSwitch(*node) && node->attr().count(kOutputSlots) > 0 &&
+      node->attr().at(kOutputSlots).list().i_size() > 0) {
+    slot = node->attr().at(kOutputSlots).list().i(0);
+    for (int i = 1; i < node->attr().at(kOutputSlots).list().i_size(); ++i) {
+      if (slot != node->attr().at(kOutputSlots).list().i(i)) {
+        slot = -1;
+        break;
+      }
+    }
   }
 
   // Increment num_inputs_ready of the output nodes and maybe add to ready
   // nodes.
+  auto& node_state = node_map_[node];
   for (const auto& port_num_output_pair : node_state.outputs) {
+    // If Switch is annotated and its output slots are always the same, we only
+    // schedule the slot that was executed. Otherwise, schedule both slots.
+    if (slot >= 0 && port_num_output_pair.first != slot) continue;
+
     for (auto* output_node : port_num_output_pair.second) {
       auto& output_state = node_map_[output_node];
       output_state.num_inputs_ready++;
@@ -780,6 +737,7 @@
         // This output node is now ready.
         output_state.time_ready = curr_time;
         ready_nodes_->AddNode(output_node);
+        VLOG(3) << "  Add output: " << output_node->name();
       }
     }
   }
@@ -787,12 +745,20 @@
 
 bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
   // Update graph_costs_ and per-op costs.
-  graph_costs_ = CombineCosts(graph_costs_, node_costs);
   const NodeDef* node = ready_nodes_->GetCurrNode();
+  auto& node_state = node_map_[node];
+  // If the graph is annotated with an execution count for this node, use that
+  // number; otherwise, assume the node is executed once.
+  node_state.execution_count = node->attr().count(kExecutionCount) == 0
+                                   ? 1
+                                   : node->attr().at(kExecutionCount).i();
+  Costs total_node_costs =
+      MultiplyCosts(node_costs, node_state.execution_count);
+  graph_costs_ = CombineCosts(graph_costs_, total_node_costs);
   const string& op_name = node->op();
 
   auto& op_cost = FindOrCreateZero(op_name, &op_to_cost_);
-  op_cost = CombineCosts(op_cost, node_costs);
+  op_cost = CombineCosts(op_cost, total_node_costs);
 
   if (VLOG_IS_ON(2)) {
     // Also keep track of op counts and costs per op (with their shapes).
@@ -806,21 +772,16 @@
   }
 
   // Update node and device states.
-  auto& node_state = node_map_[node];
   auto& device = device_[node_state.device_name];
   device.nodes_executed.push_back(node);
   // Node is scheduled when the device is available AND all the inputs are
   // ready; hence, time_scheduled is time_ready if time_ready > device curr
   // time.
-  // TODO(andiryxu): Current node_state result only records the last execution.
-  // With annotated MetaGraph we can schedule a node for multiple times.
-  // Refine NodeState structure accordingly, e.g. record time_scheduled in a
-  // vector.
   node_state.time_scheduled =
       std::max(device.GetCurrTime(), node_state.time_ready);
   // Override device curr time with the time_scheduled.
   device.device_costs.execution_time = node_state.time_scheduled;
-  device.device_costs = CombineCosts(device.device_costs, node_costs);
+  device.device_costs = CombineCosts(device.device_costs, total_node_costs);
   auto curr_time = device.GetCurrTime();
   node_state.time_finished = curr_time;
 
@@ -833,7 +794,8 @@
         node_state.time_no_references[port_num] = curr_time;
       } else {
         device.memory_usage +=
-            CalculateOutputSize(node_state.output_properties, port_num);
+            CalculateOutputSize(node_state.output_properties, port_num) *
+            node_state.execution_count;
         device.nodes_in_memory.insert(std::make_pair(node, port_num));
       }
     }
@@ -841,15 +803,16 @@
 
   // Update device's per-op cost.
   auto& device_op_cost = FindOrCreateZero(op_name, &device.op_to_cost);
-  device_op_cost = CombineCosts(device_op_cost, node_costs);
+  device_op_cost = CombineCosts(device_op_cost, total_node_costs);
 
   VLOG(3) << "Op scheduled -- name: " << node->name() << ", op: " << node->op()
           << ", device: " << node->device()
+          << ", execution_count: " << node_state.execution_count
           << ", ready: " << node_state.time_ready.count()
           << ", scheduled: " << node_state.time_scheduled.count()
           << ", finished: " << node_state.time_finished.count();
 
-  // Check outputs, add ready nodes to queue.
+  // Checks outputs, and adds ready nodes to queue.
   AddOutputNodesToReadyQueue(node, curr_time);
 
   // Increment num_outputs_executed of the input nodes and maybe update memory.
@@ -866,7 +829,8 @@
       input_state.time_no_references[port] = curr_time;
       auto& input_device = device_[input_state.device_name];
       input_device.memory_usage -=
-          CalculateOutputSize(input_state.output_properties, port);
+          CalculateOutputSize(input_state.output_properties, port) *
+          input_state.execution_count;
 
       input_device.nodes_in_memory.erase(std::make_pair(input, port));
     }
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h
index cceca71..e8e1622 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.h
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.h
@@ -71,14 +71,14 @@
   // time_no_references.
 
   // How many times this node has been executed, e.g. in a while loop.
-  int num_executed_times;
+  int execution_count;
 
   NodeState() {
     num_inputs_ready = 0;
     time_ready = Costs::Duration::max();
     time_scheduled = Costs::Duration::max();
     time_finished = Costs::Duration::max();
-    num_executed_times = 0;
+    execution_count = 0;
     // Note that num_outputs_executed and time_no_references are not initialized
     // here, since we don't know the size (i.e., # outputs for this node).
   }
@@ -323,8 +323,6 @@
                           std::map<string, Costs>* op_cost);
   float Round2(const float x) const;
   bool IsPersistentNode(const NodeDef* node) const;
-  bool AddSwitchOutputsToReadyQueue(const NodeDef* node, int curr_iter,
-                                    const Costs::Duration& curr_time);
   void AddOutputNodesToReadyQueue(const NodeDef* node,
                                   const Costs::Duration& curr_time);
 
@@ -358,10 +356,6 @@
   bool track_mem_usage_snapshot_;
   const bool use_aggressive_shape_inference_;
 
-  // Whether the input graph includes Switch nodes annotated with output slots
-  // information.
-  bool switch_outputs_annotated_ = false;
-
   VirtualPlacer placer_;  // owned.
 };
 
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index 3b48263..38fd380 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -873,8 +873,8 @@
     grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
   }
 
-  // A simple while loop strengthened with Switch outputs.
-  void CreateGrapplerItemWithLoopSwitchOutputs() {
+  // A simple while loop annotated with execution counts and Switch outputs.
+  void CreateGrapplerItemWithLoopAnnotated() {
     // Test graph produced in python using:
     /*
       with tf.Graph().as_default():
@@ -909,6 +909,12 @@
       }
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 1
+    }
+  }
 }
 node {
   name: "ones"
@@ -936,6 +942,12 @@
       }
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 1
+    }
+  }
 }
 node {
   name: "while/Enter"
@@ -965,6 +977,12 @@
       i: 10
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 1
+    }
+  }
 }
 node {
   name: "while/Enter_1"
@@ -994,6 +1012,12 @@
       i: 10
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 1
+    }
+  }
 }
 node {
   name: "while/Merge"
@@ -1012,6 +1036,12 @@
       type: DT_INT32
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 10
+    }
+  }
 }
 node {
   name: "while/Merge_1"
@@ -1030,6 +1060,12 @@
       type: DT_FLOAT
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 10
+    }
+  }
 }
 node {
   name: "while/Less/y"
@@ -1052,6 +1088,12 @@
       }
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 10
+    }
+  }
 }
 node {
   name: "while/Less"
@@ -1064,11 +1106,23 @@
       type: DT_INT32
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 10
+    }
+  }
 }
 node {
   name: "while/LoopCond"
   op: "LoopCond"
   input: "while/Less"
+  attr {
+    key: "_execution_count"
+    value {
+      i: 10
+    }
+  }
 }
 node {
   name: "while/Switch"
@@ -1090,6 +1144,12 @@
     }
   }
   attr {
+    key: "_execution_count"
+    value {
+      i: 11
+    }
+  }
+  attr {
     key: "_output_slot_vector"
     value {
       list {
@@ -1128,6 +1188,12 @@
     }
   }
   attr {
+    key: "_execution_count"
+    value {
+      i: 11
+    }
+  }
+  attr {
     key: "_output_slot_vector"
     value {
       list {
@@ -1156,6 +1222,12 @@
       type: DT_INT32
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 10
+    }
+  }
 }
 node {
   name: "while/Identity_1"
@@ -1167,6 +1239,12 @@
       type: DT_FLOAT
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 10
+    }
+  }
 }
 node {
   name: "while/add/y"
@@ -1189,6 +1267,12 @@
       }
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 10
+    }
+  }
 }
 node {
   name: "while/add"
@@ -1201,6 +1285,12 @@
       type: DT_INT32
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 10
+    }
+  }
 }
 node {
   name: "while/concat/axis"
@@ -1223,6 +1313,12 @@
       }
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 10
+    }
+  }
 }
 node {
   name: "while/concat"
@@ -1248,6 +1344,12 @@
       type: DT_INT32
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 10
+    }
+  }
 }
 node {
   name: "while/NextIteration"
@@ -1259,6 +1361,12 @@
       type: DT_INT32
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 10
+    }
+  }
 }
 node {
   name: "while/NextIteration_1"
@@ -1270,6 +1378,12 @@
       type: DT_FLOAT
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 10
+    }
+  }
 }
 node {
   name: "while/Exit"
@@ -1281,6 +1395,12 @@
       type: DT_INT32
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 1
+    }
+  }
 }
 node {
   name: "while/Exit_1"
@@ -1292,6 +1412,12 @@
       type: DT_FLOAT
     }
   }
+  attr {
+    key: "_execution_count"
+    value {
+      i: 1
+    }
+  }
 }
 versions {
   producer: 21
@@ -1305,6 +1431,115 @@
     grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
   }
 
+  // A simple Switch/Merge condition graph; the only fetch node is Merge.
+  void CreateGrapplerItemWithCondition() {
+    // Handcrafted: a, Less (const bool) -> Switch -> First/Second -> Merge.
+    const string gdef_ascii = R"EOF(
+node {
+  name: "a"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_FLOAT
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_FLOAT
+        tensor_shape {
+        }
+        float_val: 2.0
+      }
+    }
+  }
+}
+node {
+  name: "Less"
+  op: "Const"
+  attr {
+    key: "dtype"
+    value {
+      type: DT_BOOL
+    }
+  }
+  attr {
+    key: "value"
+    value {
+      tensor {
+        dtype: DT_BOOL
+        tensor_shape {
+        }
+        tensor_content: "\001"
+      }
+    }
+  }
+}
+node {
+  name: "Switch"
+  op: "Switch"
+  input: "a"
+  input: "Less"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+}
+node {
+  name: "First"
+  op: "Identity"
+  input: "Switch"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+}
+node {
+  name: "Second"
+  op: "Identity"
+  input: "Switch:1"
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+}
+node {
+  name: "Merge"
+  op: "Merge"
+  input: "First"
+  input: "Second"
+  attr {
+    key: "N"
+    value {
+      i: 2
+    }
+  }
+  attr {
+    key: "T"
+    value {
+      type: DT_FLOAT
+    }
+  }
+}
+versions {
+  producer: 27
+})EOF";
+
+    grappler_item_.reset(new GrapplerItem);
+    CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
+                                                &grappler_item_->graph));
+    grappler_item_->id = "test_graph";
+    grappler_item_->fetch = {"Merge"};
+  }
+
   // Create a FusedBatchNorm op that has multiple output ports.
   void CreateGrapplerItemWithInterDeviceTransfers() {
     tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
@@ -2379,87 +2614,155 @@
   ValidateDependencyChain(start_times, {"while/Switch_1", "while/Exit_1"});
 }
 
-TEST_F(VirtualSchedulerTest, WhileLoopWithSwitchOutputs) {
-  // Init.
-  CreateGrapplerItemWithLoopSwitchOutputs();
-  InitScheduler();
+TEST_F(VirtualSchedulerTest, AnnotatedWhileLoop) {
+  {
+    // Init.
+    CreateGrapplerItemWithLoop();
+    InitScheduler();
 
-  // Runs the scheduler.
-  RunScheduler("");
+    // Runs the scheduler.
+    RunScheduler("");
+    Costs c = scheduler_->Summary();
 
-  RunMetadata metadata;
-  scheduler_->Summary(&metadata);
-
-  // Nodes in topological order:
-  // * const, ones
-  // * while/Enter, while/Enter_1
-  // * while/Merge, while/Merge_1
-  // * while/Less/y
-  // * while/Less
-  // * while/LoopCond
-  // * while/Switch, while/Switch_1
-  // * while/Identity, while/Identity_1, while/Exit, while/Exit_1
-  // * while/add/y, while/concat/axis
-  // * while/add, while/concat
-  // * while/NextIteration, while/NextIteration_1
-
-  int num_next_iteration = 0;
-  int num_next_iteration_1 = 0;
-  int num_exit = 0;
-  int num_exit_1 = 0;
-  int64 next_iter_start_micro;
-  int64 next_iter_1_start_micro;
-  int64 exit_start_micro;
-  int64 exit_1_start_micro;
-
-  std::unordered_map<string, int64> start_times;
-  for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
-    for (const auto& stats : device_step_stats.node_stats()) {
-      start_times[stats.node_name()] = stats.all_start_micros();
-      if (stats.node_name() == "while/NextIteration") {
-        ++num_next_iteration;
-        next_iter_start_micro = stats.all_start_micros();
-      } else if (stats.node_name() == "while/NextIteration_1") {
-        ++num_next_iteration_1;
-        next_iter_1_start_micro = stats.all_start_micros();
-      } else if (stats.node_name() == "while/Exit") {
-        ++num_exit;
-        exit_start_micro = stats.all_start_micros();
-      } else if (stats.node_name() == "while/Exit_1") {
-        ++num_exit_1;
-        exit_1_start_micro = stats.all_start_micros();
-      }
-    }
+    EXPECT_EQ(23, c.execution_time.asMicroSeconds().count());
+    // Both while/Merge and while/Merge_1 are scheduled twice.
+    EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
+    EXPECT_FALSE(c.inaccurate);
+    EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
   }
 
-  // Makes sure we run the loop body for ten times.
-  EXPECT_EQ(10, num_next_iteration);
-  EXPECT_EQ(10, num_next_iteration_1);
-  EXPECT_EQ(1, num_exit);
-  EXPECT_EQ(1, num_exit_1);
+  {
+    // Init with the _execution_count-annotated loop graph.
+    CreateGrapplerItemWithLoopAnnotated();
+    InitScheduler();
 
-  // Start times of while/NextIteration and while/NextIteration_1 should be
-  // different, so should be those of while/Exit and while/Exit_1.
-  EXPECT_NE(next_iter_start_micro, next_iter_1_start_micro);
-  EXPECT_NE(exit_start_micro, exit_1_start_micro);
+    // Runs the scheduler.
+    RunScheduler("");
+    Costs c = scheduler_->Summary();
 
-  // Checks dependency among the nodes; no matter what scheduling mechanism we
-  // use, the scheduled ops should follow these dependency chains.
-  // We have to break the loop into two parts, identified by Switch outputs.
-  ValidateDependencyChain(
-      start_times,
-      {"Const", "while/Enter", "while/Merge", "while/Less/y", "while/Less",
-       "while/LoopCond", "while/Switch", "while/Exit"});
-  ValidateDependencyChain(start_times, {"while/Identity", "while/add/y",
-                                        "while/add", "while/NextIteration"});
-  ValidateDependencyChain(
-      start_times, {"ones", "while/Enter_1", "while/Merge_1", "while/Switch_1",
-                    "while/Exit_1"});
-  ValidateDependencyChain(start_times, {"while/Identity_1", "while/concat",
-                                        "while/NextIteration_1"});
-  ValidateDependencyChain(
-      start_times, {"while/Identity", "while/concat/axis", "while/concat"});
-  ValidateDependencyChain(start_times, {"while/Identity", "while/add"});
+    // Merge's cost is counted twice, execution_count times each, but since
+    // Merge's cost is minimal, we keep this behavior here.
+    EXPECT_EQ(178, c.execution_time.asMicroSeconds().count());
+    // Both while/Merge and while/Merge_1 are scheduled twice.
+    EXPECT_EQ(grappler_item_->graph.node_size() + 2, c.num_ops_total);
+    EXPECT_FALSE(c.inaccurate);
+    EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
+  }
+}
+
+TEST_F(VirtualSchedulerTest, Condition) {
+  // Without annotation: both Switch outputs feed executed nodes.
+  {
+    // Inits.
+    CreateGrapplerItemWithCondition();
+    InitScheduler();
+
+    // Runs the scheduler.
+    RunScheduler("");
+    RunMetadata metadata;
+    Costs c = scheduler_->Summary(&metadata);
+
+    // Nodes in topological order: a/Less, Switch, First/Second, Merge.
+    int num_a = 0;
+    int num_less = 0;
+    int num_switch = 0;
+    int num_first = 0;
+    int num_second = 0;
+    int num_merge = 0;
+
+    for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
+      for (const auto& stats : device_step_stats.node_stats()) {
+        if (stats.node_name() == "a") {
+          ++num_a;
+        } else if (stats.node_name() == "Less") {
+          ++num_less;
+        } else if (stats.node_name() == "Switch") {
+          ++num_switch;
+        } else if (stats.node_name() == "First") {
+          ++num_first;
+        } else if (stats.node_name() == "Second") {
+          ++num_second;
+        } else if (stats.node_name() == "Merge") {
+          ++num_merge;
+        }
+      }
+    }
+
+    EXPECT_EQ(1, num_a);
+    EXPECT_EQ(1, num_less);
+    EXPECT_EQ(1, num_switch);
+    EXPECT_EQ(1, num_first);
+    EXPECT_EQ(1, num_second);
+    EXPECT_EQ(2, num_merge);
+
+    EXPECT_EQ(7, c.execution_time.asMicroSeconds().count());
+    // Merge is executed twice.
+    EXPECT_EQ(grappler_item_->graph.node_size() + 1, c.num_ops_total);
+    EXPECT_FALSE(c.inaccurate);
+    EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
+  }
+
+  // With annotation: kOutputSlots on Switch lists only output slot 0.
+  {
+    // Inits.
+    CreateGrapplerItemWithCondition();
+
+    // Annotates the Switch node with kOutputSlots = {0}.
+    for (auto& node : *grappler_item_->graph.mutable_node()) {
+      if (node.name() == "Switch") {
+        AttrValue attr_output_info;
+        // Lists only slot 0, so Second (reads Switch:1) is never scheduled.
+        (*attr_output_info.mutable_list()).add_i(0);
+        AddNodeAttr(kOutputSlots, attr_output_info, &node);
+      }
+    }
+
+    InitScheduler();
+
+    // Runs the scheduler.
+    RunScheduler("");
+    RunMetadata metadata;
+    Costs c = scheduler_->Summary(&metadata);
+
+    // Nodes in topological order: a/Less, Switch, First, Merge.
+    int num_a = 0;
+    int num_less = 0;
+    int num_switch = 0;
+    int num_first = 0;
+    int num_second = 0;
+    int num_merge = 0;
+
+    for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
+      for (const auto& stats : device_step_stats.node_stats()) {
+        if (stats.node_name() == "a") {
+          ++num_a;
+        } else if (stats.node_name() == "Less") {
+          ++num_less;
+        } else if (stats.node_name() == "Switch") {
+          ++num_switch;
+        } else if (stats.node_name() == "First") {
+          ++num_first;
+        } else if (stats.node_name() == "Second") {
+          ++num_second;
+        } else if (stats.node_name() == "Merge") {
+          ++num_merge;
+        }
+      }
+    }
+
+    EXPECT_EQ(1, num_a);
+    EXPECT_EQ(1, num_less);
+    EXPECT_EQ(1, num_switch);
+    EXPECT_EQ(1, num_first);
+    EXPECT_EQ(0, num_second);
+    EXPECT_EQ(1, num_merge);
+
+    EXPECT_EQ(5, c.execution_time.asMicroSeconds().count());
+    // Second is not executed, and Merge executes only once.
+    EXPECT_EQ(grappler_item_->graph.node_size() - 1, c.num_ops_total);
+    EXPECT_FALSE(c.inaccurate);
+    EXPECT_EQ(0, c.num_ops_with_unknown_shapes);
+  }
 }
 
 TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index ec16c42..5b2f1e5 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -1159,8 +1159,7 @@
   FunctionLibraryRuntime* flr = ctx->mutable_function_library_runtime();
 
   // 1. Inline symbolic gradient node.
-  const InlineFunctionBodyOptions default_inline_opts;
-  const bool expanded = ExpandInlineFunctions(flr, &graph, default_inline_opts);
+  const bool expanded = ExpandInlineFunctions(flr, &graph);
   if (!expanded) {
     return errors::Internal("Failed to expand SymbolicGradient op");
   }
@@ -1182,7 +1181,7 @@
 
   // 2. Recursively inline nested function calls.
   int iteration = 0;
-  while (ExpandInlineFunctions(flr, &graph, default_inline_opts)) {
+  while (ExpandInlineFunctions(flr, &graph)) {
     if (++iteration >= 50) {
       VLOG(2) << "Break symbolic gradient inlining loop at iteration #"
               << iteration;
diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD
index c985817..5cb9544 100644
--- a/tensorflow/python/compiler/tensorrt/BUILD
+++ b/tensorflow/python/compiler/tensorrt/BUILD
@@ -67,16 +67,12 @@
     ],
 )
 
-# TODO(aaroey): this wrapper has been causing troubles of double linking, so
-# either get rid of it, or split to make it contain minimum dependencies.
 tf_py_wrap_cc(
     name = "wrap_conversion",
     srcs = ["trt_conversion.i"],
     copts = tf_copts(),
     deps = [
         "//tensorflow/compiler/tf2tensorrt:py_utils",
-        "//tensorflow/compiler/tf2tensorrt:trt_conversion",
-        "//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
         "//third_party/python_runtime:headers",
     ],
 )
diff --git a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
index 993f5a9..d0c5f25 100644
--- a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
+++ b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
@@ -221,7 +221,7 @@
     size_500 = _construct_graph_of_size(500)
 
     # Check construction time grows approx. linearly with size.
-    e = 2  # Fudge factor to prevent flakiness.
+    e = 3  # Fudge factor to prevent flakiness.
     self.assertLess(size_500, (10 * e) * size_50)
 
   def test_no_mask_tracking(self):
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 0121410..e4371c2 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -88,7 +88,6 @@
   tensor_map = {}  # Map {reference_tensor: corresponding_tensor}
   if input_tensors is None:
     # Create placeholders to build the model on top of.
-    input_layers = []
     input_tensors = []
     for layer in model._input_layers:
       input_tensor = Input(
@@ -100,10 +99,6 @@
       # Cache newly created input layer.
       newly_created_input_layer = input_tensor._keras_history[0]
       layer_map[layer] = newly_created_input_layer
-
-    for original_input_layer, cloned_input_layer in zip(model._input_layers,
-                                                        input_layers):
-      layer_map[original_input_layer] = cloned_input_layer
   else:
     # Make sure that all input tensors come from a Keras layer.
     # If tensor comes from an input layer: cache the input layer.
diff --git a/third_party/icu/udata.patch b/third_party/icu/udata.patch
index d6d5910..2af6718 100644
--- a/third_party/icu/udata.patch
+++ b/third_party/icu/udata.patch
@@ -1,3 +1,18 @@
+--- /icu4c/source/common/unicode/uconfig.h	2018-06-19 22:34:56.000000000 -0700
++++ /ice4c/source/common/unicode/uconfig.h.new	2019-03-12 10:12:35.896095657 -0700
+@@ -55,6 +55,11 @@
+ #include "uconfig_local.h"
+ #endif
+ 
++// Tensorflow is statically linked on all platforms.
++#ifndef U_STATIC_IMPLEMENTATION
++#define U_STATIC_IMPLEMENTATION
++#endif
++
+ /**
+  * \def U_DEBUG
+  * Determines whether to include debugging code.
+
 --- /icu4c/source/common/udata.cpp.old	2018-06-19 22:34:56.000000000 -0700
 +++ /icu4c/source/common/udata.cpp	2018-10-19 14:26:09.778950855 -0700
 @@ -18,15 +18,15 @@