Roll forward of commit: If Const node in outside compilation at head has successor that's not outside compilation at head, do not move it out of TPU computation.

Reason: in some cases, the commit will cause the same Const node to be copied both in head outside compilation and tail outside compilation. And we used the same node name for head/tail outside compilation, causing duplicate nodes with the same name.

Fix is simple: use different names for head/tail outside compilation.

PiperOrigin-RevId: 242192382
diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
index 51d6444b..5eda028 100644
--- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
+++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc
@@ -599,7 +599,8 @@
 Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
                                     FunctionLibraryDefinition* fld,
                                     const string& host_graph_func_name,
-                                    Node* xla_computation_node) {
+                                    Node* xla_computation_node,
+                                    Node* pivot_node) {
   // Temporarily use "0" as "device_ordinal". It will be rewritten with the
   // correct value in a later pass. We cannot just use placeholder value here
   // because FunctionDef instantiation does not allow placeholder value for
@@ -620,7 +621,11 @@
 
   // Copy all nodes.
   std::map<const Node*, Node*> node_map;
-  node_map[host_graph->source_node()] = main_graph->source_node();
+  if (pivot_node) {
+    node_map[host_graph->source_node()] = pivot_node;
+  } else {
+    node_map[host_graph->source_node()] = main_graph->source_node();
+  }
   node_map[host_graph->sink_node()] = main_graph->sink_node();
   Status s = Status::OK();
   auto copy_node_fn = [&](const Node* n) {
@@ -673,7 +678,7 @@
 // 2) Remove control edges.
 // 3) Prune nodes that are not useful for shape inference.
 Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
-                                  Graph* host_graph,
+                                  Graph* host_graph, Node* pivot_node,
                                   FunctionLibraryDefinition* fld) {
   // Use "0" as "device_ordinal". It does not matter for shape inference.
   AttrValue device_ordinal_attr;
@@ -717,41 +722,45 @@
     for (Node* n : nodes) {
       g->RemoveNode(n);
     }
-
-    std::map<const Node*, Node*> node_map;
-    node_map[host_graph->source_node()] = g->source_node();
-    Status s;
-    auto copy_node_fn = [&](const Node* n) {
-      if (!s.ok()) {
-        return;
-      }
-
-      if (node_map.find(n) != node_map.end()) {
-        return;
-      }
-
-      NodeDef copy_def = n->def();
-      Node* copy = g->AddNode(copy_def, &s);
-      if (!s.ok()) {
-        return;
-      }
-      for (auto e : n->in_edges()) {
-        if (node_map.find(e->src()) == node_map.end()) {
-          s = errors::Internal("Cannot find node image for ",
-                               e->src()->DebugString());
-          return;
-        }
-        g->AddEdge(node_map[e->src()], e->src_output(), copy, e->dst_input());
-      }
-
-      node_map[n] = copy;
+    Node* start_node = pivot_node ? pivot_node : host_graph->source_node();
+    // Reverse DFS from send_from_host_main_graph, and stop at start_node.
+    struct Visit {
+      Node* n;
+      bool is_exiting;
     };
-    // TODO(b/77601805): consolidate copy graph functions.
-    ReverseDFSFrom(*host_graph,
-                   std::vector<const Node*>{send_from_host_main_graph},
-                   /*enter=*/nullptr, copy_node_fn, NodeComparatorID());
-    if (!s.ok()) {
-      return s;
+    std::vector<Visit> stack{{send_from_host_main_graph, false}};
+    std::map<Node*, Node*> node_map;
+    node_map[host_graph->source_node()] = g->source_node();
+    while (!stack.empty()) {
+      Visit& curr = stack.back();
+      if (curr.is_exiting) {
+        if (node_map.find(curr.n) == node_map.end()) {
+          Node* copy = g->CopyNode(curr.n);
+          if (curr.n != start_node) {
+            for (const Edge* e : curr.n->in_edges()) {
+              auto node_iter = node_map.find(e->src());
+              if (node_iter == node_map.end()) {
+                return errors::Internal("Cannot find node image for ",
+                                        e->src()->DebugString());
+              }
+              g->AddEdge(node_iter->second, e->src_output(), copy,
+                         e->dst_input());
+            }
+          }
+          node_map[curr.n] = copy;
+        }
+        stack.pop_back();
+      } else {
+        curr.is_exiting = true;
+        if (curr.n != start_node) {
+          for (const Edge* e : curr.n->in_edges()) {
+            if (node_map.find(e->src()) != node_map.end()) {
+              continue;
+            }
+            stack.push_back({e->src(), false});
+          }
+        }
+      }
     }
 
     send_from_host = node_map[send_from_host_main_graph];
@@ -1687,13 +1696,14 @@
     DumpGraphToFile("extract_outside_compilation_before", *g, fld);
   }
 
-  std::vector<string> shape_inference_graphs;
+  auto node_name_index = g->BuildNodeNameIndex();
   for (auto& iter : clusters) {
     string xla_cluster_name = iter.first;
     Node* n = iter.second.node;
     auto const& func_name_attrs = iter.second.func_name_attrs;
     auto const& host_compute_core = iter.second.host_compute_core;
 
+    std::vector<string> shape_inference_graphs;
     bool has_outside_compilation;
     string host_graph_func_name = absl::StrCat("oc_host_graph_", n->name());
     TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
@@ -1701,14 +1711,18 @@
         func_name_attrs, func_name_attrs.name(), host_graph_func_name,
         host_compute_core, flr, fld, &shape_inference_graphs,
         &has_outside_compilation));
-    TF_RETURN_IF_ERROR(
-        ExpandHostGraphIntoMainGraph(g, fld, host_graph_func_name, n));
-    TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
-  }
 
-  for (auto shape_inference_graph_name : shape_inference_graphs) {
-    TF_RETURN_IF_ERROR(
-        RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld));
+    string pivot_name = absl::StrCat(xla_cluster_name, "/pivot");
+    Node* pivot_node = node_name_index[pivot_name];
+    TF_RETURN_IF_ERROR(ExpandHostGraphIntoMainGraph(
+        g, fld, host_graph_func_name, n, pivot_node));
+
+    TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
+
+    for (auto shape_inference_graph_name : shape_inference_graphs) {
+      TF_RETURN_IF_ERROR(RewriteShapeInferenceGraph(shape_inference_graph_name,
+                                                    g, pivot_node, fld));
+    }
   }
 
   if (VLOG_IS_ON(4)) {