[Grappler] Inline multi-device functions using common_runtime/function implementation

There are two helper functions in function_optimizer.cc that help to deal with V1 graphs with missing control dependencies and "problematic" semantics:

1) AddStrictInputSemantics
Adds control edges from all data inputs to enforce "strict inputs" semantics when needed for correctness.

2) AddFrameForwardingControlEdge
Adds a control edge from the "Enter" node to pass frame information.

PiperOrigin-RevId: 246194205
diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py
index 51443d2..fa46d73 100644
--- a/tensorflow/contrib/eager/python/evaluator.py
+++ b/tensorflow/contrib/eager/python/evaluator.py
@@ -165,8 +165,15 @@
         self.__call__(example, *args, **kwargs)
       return self.all_metric_results(summary_logdir)
     # Graph construction
-    call_op = self.__call__(
-        dataset_ops.make_one_shot_iterator(dataset).get_next(), *args, **kwargs)
+    next_value = dataset_ops.make_one_shot_iterator(dataset).get_next()
+    # Function inlining destroys strict inputs semantics (function body might
+    # start execution before all inputs are ready). When iterator is exhausted
+    # and throws out of range error, function body might be partially executed.
+    # To prevent this we add an explicit control dependency from the 'get_next'.
+    with ops.control_dependencies([next_value]):
+      has_next_value = control_flow_ops.no_op(name="iterator_has_next")
+    with ops.control_dependencies([has_next_value]):
+      call_op = self.__call__(next_value, *args, **kwargs)
     init_op = self.init_variables()
     results_op = self.all_metric_results(summary_logdir)
     return (init_op, call_op, results_op)
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6c81391..fcb0d99 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -3234,6 +3234,7 @@
         ":lib",
         ":proto_text",
         ":protos_all_cc",
+        "@com_google_absl//absl/strings",
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler/clusters:utils",
         "//tensorflow/core/grappler/clusters:virtual_cluster",
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 86dd2b6..bde8958 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -1501,6 +1501,7 @@
       "disable_inlining=", true_false(disable_inlining),
       ", ignore_noinline=", true_false(ignore_noinline),
       ", override_device=", true_false(ignore_noinline),
+      ", initialize_empty_device=", true_false(initialize_empty_device),
       ", keep_caller_node=", keep_caller_node_str(), ", output_control_src=",
       output_control_src == OutputControlSrc::kDataOutputs ? "DataOutputs"
                                                            : "ControlOutputs");
@@ -1699,7 +1700,10 @@
   for (Node* n : fbody->graph->op_nodes()) {
     NodeDef ndef = n->def();
 
-    if (options.override_device || ndef.device().empty()) {
+    if (options.override_device) {
+      ndef.set_device(caller->def().device());
+    }
+    if (options.initialize_empty_device && ndef.device().empty()) {
       ndef.set_device(caller->def().device());
     }
 
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
index 450c974..3d071db 100644
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -201,6 +201,13 @@
   // If 'true' function inlining will override explicitly specified devices
   // inside function body with the caller node device.
   bool override_device = false;
+  // If 'true' function inlining will fill an empty device annotation inside
+  // function body with the caller node device.
+  // TODO(ezhulenev): Remove this flag. This is mostly legacy-compatibility
+  // mode. We should never explicitly define devices when we inline multi-device
+  // functions. However we do that in 'lower_function_call_op.cc' and
+  // 'function_optimizer' for now.
+  bool initialize_empty_device = false;
   // Controls if we want to keep a node with the name as the function call node
   // in a graph after function inlining.
   KeepCallerNode keep_caller_node = KeepCallerNode::kDoNotKeep;
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 5d57e72..5290332 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -22,6 +22,7 @@
 #include <utility>
 #include <vector>
 
+#include "absl/strings/str_join.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/metrics.h"
 #include "tensorflow/core/common_runtime/optimization_registry.h"
@@ -607,10 +608,12 @@
     graph_->ToGraphDef(&item.graph);
 
     // It's ok to skip invalid device annotations in Grappler.
-    Status inferred_devices = item.InferDevicesFromGraph();
-    if (!inferred_devices.ok()) {
-      VLOG(3) << inferred_devices.error_message();
+    for (const Device* d : device_set_->devices()) {
+      Status added_device = item.AddDevice(d->name());
+      if (!added_device.ok()) VLOG(3) << added_device.error_message();
     }
+    VLOG(3) << "Grappler available devices: "
+            << absl::StrJoin(item.devices(), ", ");
 
     // TODO(b/114748242): Add a unit test to test this bug fix.
     if (flib_def_) {
diff --git a/tensorflow/core/common_runtime/lower_function_call_op.cc b/tensorflow/core/common_runtime/lower_function_call_op.cc
index aaa1755..4df335a 100644
--- a/tensorflow/core/common_runtime/lower_function_call_op.cc
+++ b/tensorflow/core/common_runtime/lower_function_call_op.cc
@@ -60,6 +60,7 @@
     // Tensorflow 2.0 Eager mode, and it has control outputs to represent
     // side-effects that must always execute (see `control_ret` in FunctionDef).
     inline_options.override_device = false;
+    inline_options.initialize_empty_device = true;
     inline_options.output_control_src = OutputControlSrc::kControlOutputs;
   } else {
     // Native function call (node.type_string() is the function name). These
diff --git a/tensorflow/core/common_runtime/lower_functional_ops.h b/tensorflow/core/common_runtime/lower_functional_ops.h
index 297f585..84d15a1 100644
--- a/tensorflow/core/common_runtime/lower_functional_ops.h
+++ b/tensorflow/core/common_runtime/lower_functional_ops.h
@@ -32,10 +32,8 @@
 class LowerFunctionalOpsPass : public GraphOptimizationPass {
  public:
   LowerFunctionalOpsPass() = default;
-  LowerFunctionalOpsPass(bool lower_function_calls,
-                         bool keep_lowered_nodes_fetchable)
-      : lower_function_calls_(lower_function_calls),
-        keep_lowered_nodes_fetchable_(keep_lowered_nodes_fetchable) {}
+  LowerFunctionalOpsPass(bool keep_lowered_nodes_fetchable)
+      : keep_lowered_nodes_fetchable_(keep_lowered_nodes_fetchable) {}
 
   Status Run(const GraphOptimizationPassOptions& options) override;
 
@@ -45,10 +43,6 @@
       "_lower_as_multi_device_function";
 
  private:
-  // TODO(ezhulenev): This is only required until Grappler function optimizer is
-  // not migrated to use function inlining from common_runtime.
-  bool lower_function_calls_ = true;
-
   // If defined use the value to control if functional ops must be fetchable
   // after lowering (we add IdentityN in place of all lowered nodes). If not
   // defined, this option will be inferred automatically from the graph (in
diff --git a/tensorflow/core/graph/control_flow.h b/tensorflow/core/graph/control_flow.h
index 5abe77f..cbef1c2 100644
--- a/tensorflow/core/graph/control_flow.h
+++ b/tensorflow/core/graph/control_flow.h
@@ -25,6 +25,15 @@
 
 // Control flow info for a graph node.
 struct ControlFlowInfo {
+  // 'frame' and 'parent_frame' are pointers to:
+  //
+  // a) One of the Enter nodes corresponding to the loop body, if the node
+  //    executes inside a loop. If multiple tensors enter the while loop, it's
+  //    undefined which Enter node will be used.
+  //
+  // b) SOURCE node (node.id() == Graph::kSourceId), if the node is not inside
+  //    any of the while loops.
+
   const Node* frame = nullptr;         // frame of a node
   const Node* parent_frame = nullptr;  // parent frame of a node
   string frame_name;                   // frame name of a node
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 401f4d7..955935f 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -149,8 +149,8 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:graph_topology_view",
+        "//tensorflow/core/grappler:graph_view",
         "//tensorflow/core/grappler:grappler_item",
-        "//tensorflow/core/grappler:mutable_graph_view",
         "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/utils:functions",
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index 321ba1f..630fcde 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -28,9 +28,9 @@
 #include "tensorflow/core/common_runtime/device_set.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
-#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/common_runtime/lower_if_op.h"
+#include "tensorflow/core/common_runtime/lower_while_op.h"
 #include "tensorflow/core/common_runtime/placer.h"
-#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function.pb.h"
@@ -40,59 +40,25 @@
 #include "tensorflow/core/framework/op_def.pb.h"
 #include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/control_flow.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/grappler/graph_topology_view.h"
+#include "tensorflow/core/grappler/graph_view.h"
 #include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/mutable_graph_view.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/grappler/utils/functions.h"
-#include "tensorflow/core/grappler/utils/topological_sort.h"
-#include "tensorflow/core/grappler/utils/traversal.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 
 namespace tensorflow {
 namespace grappler {
 namespace {
 
-// WARNING: Code in this file implicitly assumes that function input and output
-// arguments are plain tensors (tensor lists are not supported). Function inputs
-// and outputs are always expanded to a single placeholder or output tensor.
-// With this assumption, the calling node's input/output ports always match
-// function input/output arguments.
-//
-// This is guaranteed by the implementation of MakeGrapplerFunctionItem.
+constexpr const char* const kFuncAttr = FunctionLibraryDefinition::kFuncAttr;
 
 // Mark functions that were created as a result of function specialization.
-constexpr char kGrapplerSpecializedFuncAttr[] = "_GrapplerSpecializedFunc";
-
-// Name of the attribute that defines the function for indirect function calls.
-constexpr char kFuncAttrName[] = "f";
-
-constexpr char kNoInlineAttr[] = "_noinline";
-
-// Name of the node that will have control edges from function input nodes, and
-// also used as a new destination for incoming control edges.
-constexpr char kInputsReadyNodeName[] = "inputs_ready";
-
-// Name of the node that will have control edges from function control output
-// nodes, and also used as a new source of outgoing control edges. This node
-// will guarantee that all side-effects inside function body will be executed
-// after function inlining.
-constexpr char kSideEffectsExecutedNodeName[] = "side_effects_executed";
-
-bool AttrIsTrue(const FunctionDef& func, const string& attr) {
-  return func.attr().count(attr) != 0 && func.attr().at(attr).b();
-}
-
-bool MarkedSpecialized(const FunctionDef& func) {
-  return AttrIsTrue(func, kGrapplerSpecializedFuncAttr);
-}
-
-bool MarkedNoInline(const FunctionDef& func) {
-  return AttrIsTrue(func, kNoInlineAttr);
-}
+constexpr const char* const kGrapplerSpecializedFuncAttr =
+    "_GrapplerSpecializedFunc";
 
 // There are two ways of calling a Tensorflow function:
 //
@@ -114,7 +80,7 @@
     return false;
   }
 
-  auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName);
+  auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
   return func_attr != nullptr && func_attr->has_func() &&
          func_attr->func().name() == func.signature().name();
 }
@@ -125,7 +91,7 @@
     return AttrSlice(func_node);
 
   } else if (IsIndirectFunctionCall(func, func_node)) {
-    auto* func_attr = AttrSlice(func_node).Find(kFuncAttrName);
+    auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
     return AttrSlice(&func_attr->func().attr());
 
   } else {
@@ -293,52 +259,21 @@
   const FunctionLibraryDefinition& function_library() const {
     return function_library_;
   }
-
-  FunctionLibraryDefinition* mutable_function_library() {
-    return &function_library_;
-  }
-
-  FunctionLibraryRuntime* mutable_function_library_runtime() {
-    InitializeFunctionLibraryRuntime();
-    return flr_;
-  }
+  FunctionLibraryDefinition& function_library() { return function_library_; }
 
   const absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
   tensor_mapping() const {
     return tensor_mapping_;
   }
 
-  const absl::flat_hash_map<string, std::vector<string>>& control_overrides()
-      const {
-    return control_overrides_;
-  }
-
   const GraphView& graph_view() const { return graph_view_; }
 
-  const DeviceSet* devices() const {
-    // Create fake devices lazily only if we need a DeviceSet.
-    if (available_devices_.empty() && !item_->devices().empty()) {
-      for (const string& name : item_->devices()) {
-        auto device = absl::make_unique<FakeDevice>(name);
-        available_device_set_.AddDevice(device.get());
-        available_devices_.push_back(std::move(device));
-      }
-    }
-    return &available_device_set_;
-  }
-
   bool IsFetchNode(const string& node_name) const {
     return absl::c_any_of(item_->fetch, [&](const string& fetch) {
       return ParseTensorName(fetch).node() == node_name;
     });
   }
 
-  bool IsKeepOp(const string& node_name) const {
-    return absl::c_any_of(item_->keep_ops, [&](const string& keep_node) {
-      return keep_node == node_name;
-    });
-  }
-
   bool IsTrulyConst(const string& name) const {
     return TrulyConstNode(name) != nullptr;
   }
@@ -382,17 +317,6 @@
     }
   }
 
-  void AddControlOverrides(const NodeDef& func_node,
-                           const std::vector<string>& control_overrides) {
-    VLOG(4) << "Add control overrides: from=" << func_node.name() << " to: ["
-            << absl::StrJoin(control_overrides, ", ") << "]";
-
-    control_overrides_[func_node.name()].reserve(control_overrides.size());
-    for (const string& control_override : control_overrides) {
-      control_overrides_[func_node.name()].push_back(control_override);
-    }
-  }
-
  private:
   static absl::flat_hash_map<string, const NodeDef*> InferTrulyConstNodes(
       const GrapplerItem& item, const GraphDef& graph) {
@@ -411,39 +335,12 @@
     return const_nodes;
   }
 
-  void InitializeFunctionLibraryRuntime() {
-    if (!flr_) {
-      Env* env = Env::Default();
-      std::vector<std::unique_ptr<Device>> devices;
-      devices.push_back(absl::make_unique<FakeDevice>(env, "/device:CPU:0"));
-      device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
-      OptimizerOptions optimizer_opts;
-      optimizer_opts.set_do_function_inlining(true);
-      process_flr_.reset(new ProcessFunctionLibraryRuntime(
-          device_mgr_.get(), env, item_->graph.versions().producer(),
-          &function_library_, optimizer_opts));
-      flr_ = process_flr_->GetFLR(device_mgr_->ListDevices()[0]->name());
-    }
-  }
-
   const GrapplerItem* item_;  // must outlive this object
   RewriterConfig::Toggle opt_level_;
 
   // Function library constructed from current graph.
   FunctionLibraryDefinition function_library_;
 
-  // These fields initialized lazily only if needed.
-  std::unique_ptr<DeviceMgr> device_mgr_;
-  std::unique_ptr<ProcessFunctionLibraryRuntime> process_flr_;
-  FunctionLibraryRuntime* flr_ = nullptr;
-
-  // List of available `FakedDevices` (lazily initialized, see devices()).
-  mutable std::vector<std::unique_ptr<Device>> available_devices_;
-
-  // DeviceSet of fake devices (`FakeDevice`) constructed from
-  // item_.devices() (lazily initialized).
-  mutable DeviceSet available_device_set_;
-
   // Nodes that are Const and not in feed.
   absl::flat_hash_map<string, const NodeDef*> truly_const_nodes_;
   // Specialized functions.
@@ -451,24 +348,15 @@
                       const FunctionSpecialization>
       specialized_functions_;
 
-  // After function inlining and specialization, the optimized graph might be in
-  // invalid state, nodes can read from non-existing function call nodes that
-  // were inlined, or they can read from output index that is no longer valid
-  // after unused outputs pruning.
+  // After function specialization, the optimized graph might be in invalid
+  // state, nodes can read from output index that is no longer valid after
+  // unused outputs pruning.
   //
   // Tensor mapping that has to be applied to the graph after all functions
   // optimizations (invalidated tensor id -> optimized graph tensor id).
   absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
       tensor_mapping_;
 
-  // When we inline a function into the optimized graph, we no longer have the
-  // function call node to anchor control dependencies. Instead we must expand
-  // each function call control output edge into multiple control dependencies
-  // to all side-effectful ops inside the function body.
-  //
-  // Invalidated function call node name -> Inlined side-effectful nodes
-  absl::flat_hash_map<string, std::vector<string>> control_overrides_;
-
   // Use graph view to find active outputs of the function caller nodes.
   GraphView graph_view_;
 
@@ -595,11 +483,10 @@
   // Keep only non-const inputs.
   std::vector<string> keep_inputs;
   const auto& inputs = specialized_func_node->input();
-  std::copy_if(inputs.begin(), inputs.end(), std::back_inserter(keep_inputs),
-               [&](const string& input) {
-                 return specialization.const_inputs.find(input) ==
-                        specialization.const_inputs.end();
-               });
+  absl::c_copy_if(inputs, std::back_inserter(keep_inputs),
+                  [&](const string& input) {
+                    return !specialization.const_inputs.contains(input);
+                  });
 
   specialized_func_node->clear_input();
   for (const auto& keep : keep_inputs) specialized_func_node->add_input(keep);
@@ -613,7 +500,7 @@
     }
 
     for (const string& ctrl : specialization.control_deps) {
-      if (existing_control_deps.find(ctrl) == existing_control_deps.end()) {
+      if (!existing_control_deps.contains(ctrl)) {
         VLOG(3) << "Forward control dependency: input=" << ctrl;
         specialized_func_node->add_input(ctrl);
       }
@@ -641,8 +528,7 @@
     const string& input = func_node.input(i);
     if (IsControlInput(input)) break;
 
-    if (specialization.const_inputs.find(input) ==
-        specialization.const_inputs.end()) {
+    if (!specialization.const_inputs.contains(input)) {
       DataType dt = tin->list().type(i);
       (*attr)["Tin"].mutable_list()->add_type(dt);
     }
@@ -666,8 +552,7 @@
 
   // Keep output types of active outputs only.
   for (int i = 0; i < tout->list().type_size(); ++i) {
-    if (specialization.active_outputs.find(i) !=
-        specialization.active_outputs.end()) {
+    if (specialization.active_outputs.contains(i)) {
       DataType dt = tout->list().type(i);
       (*attr)["Tout"].mutable_list()->add_type(dt);
     }
@@ -683,7 +568,7 @@
 
   } else if (IsIndirectFunctionCall(func, func_node)) {
     auto* attr = specialized_func_node->mutable_attr();
-    (*attr)[kFuncAttrName].mutable_func()->set_name(specialized_func_name);
+    (*attr)[kFuncAttr].mutable_func()->set_name(specialized_func_name);
 
   } else {
     return errors::InvalidArgument("Unknown function call site");
@@ -853,8 +738,7 @@
   (*specialized_attr)[kGrapplerSpecializedFuncAttr].set_b(true);
 
   // Add specialized function to the library.
-  TF_RETURN_IF_ERROR(
-      ctx->mutable_function_library()->AddFunctionDef(specialized_func));
+  TF_RETURN_IF_ERROR(ctx->function_library().AddFunctionDef(specialized_func));
 
   // Add a function call node for the specialized function.
   NodeDef* specialized_func_node = optimized_graph->add_node();
@@ -881,9 +765,21 @@
 // 2) Inline function calls.
 // 3) Convert Graph back to the GraphDef.
 
+constexpr const char* const kLowerUsingSwitchMergeAttr =
+    LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
+constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
+    LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;
+
 using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
 using OutputControlSource = InlineFunctionBodyOptions::OutputControlSource;
 
+// Checks if boolean attribute is defined and its value is 'true'.
+bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
+  bool match;
+  Status s = GetNodeAttr(n->attrs(), attr_name, &match);
+  return s.ok() && match;
+}
+
 // Checks if string attribute is defined and it's not empty.
 bool CheckStringAttr(const Node* n, absl::string_view attr_name) {
   string match;
@@ -891,6 +787,14 @@
   return s.ok() && !match.empty();
 }
 
+bool LowerUsingSwitchMergeIsOn(const Node* n) {
+  return CheckBoolAttr(n, kLowerUsingSwitchMergeAttr);
+}
+
+bool LowerAsMultiDeviceFunctionIsOn(const Node* n) {
+  return CheckBoolAttr(n, kLowerAsMultiDeviceFunctionAttr);
+}
+
 bool MarkedForTpuCompilation(const Node* n) {
   static constexpr const char* const kTpuReplicateAttr = "_tpu_replicate";
   return CheckStringAttr(n, kTpuReplicateAttr);
@@ -977,10 +881,77 @@
   return Status::OK();
 }
 
+// Validates that no dead tensor can reach function output.
+Status ValidateNoDeadOutputs(const FunctionLibraryDefinition& flib_def,
+                             const FunctionBody& fbody) {
+  absl::flat_hash_set<const Node*> output_nodes = {fbody.ret_nodes.begin(),
+                                                   fbody.ret_nodes.end()};
+
+  // Find all nodes that can produce dead tensors.
+  std::vector<const Node*> dead_tensor_sources;
+  for (const Node* n : fbody.graph->nodes()) {
+    if (n->IsSwitch()) {
+      VLOG(4) << "Add dead tensors source. Switch node: " << n->name();
+      dead_tensor_sources.push_back(n);
+      continue;
+    }
+
+    // Native function call can also produce dead tensors if the function body
+    // has mergeless switches.
+    const FunctionDef* fdef = flib_def.Find(n->type_string());
+    if (fdef != nullptr) {
+      std::unique_ptr<FunctionBody> nested_fbody;
+
+      NameAttrList func;
+      TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(n->def(), &func));
+      TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
+                                                 &flib_def, &nested_fbody));
+
+      if (!ValidateNoDeadOutputs(flib_def, *nested_fbody).ok()) {
+        VLOG(4) << "Add dead tensors source. Function call: " << func.name()
+                << " node=" << n->name();
+        dead_tensor_sources.push_back(n);
+      }
+    }
+  }
+
+  for (const Node* dead_tensor_source : dead_tensor_sources) {
+    bool has_dead_output = false;
+
+    const auto is_output_node = [&](const Node* n) -> void {
+      const auto it = output_nodes.find(n);
+      if (it != output_nodes.end()) {
+        VLOG(4) << "Found a path to output node from dead tensor source: "
+                << dead_tensor_source->name() << " ---> " << (*it)->name();
+        has_dead_output = true;
+      }
+    };
+
+    // Stop DFS traversal at a Merge node or if already found a dead output.
+    const auto stop_traversal = [&has_dead_output](const Edge& edge) -> bool {
+      return !edge.src()->IsMerge() || has_dead_output;
+    };
+
+    DFSFrom(*fbody.graph, {dead_tensor_source}, /*enter=*/is_output_node,
+            /*leave=*/{}, NodeComparatorName{},
+            /*edge_filter=*/stop_traversal);
+
+    if (has_dead_output) {
+      return errors::Internal(
+          "Can't inline a function with dead outputs. Dead tensor source: ",
+          SummarizeNode(*dead_tensor_source));
+    }
+  }
+
+  return Status::OK();
+}
+
 // Makes an instance of FunctionBody for inlining from a Node.
 Status MakeFunctionBodyForInlining(const Node& node,
                                    const FunctionLibraryDefinition& flib_def,
                                    std::unique_ptr<FunctionBody>* fbody) {
+  VLOG(3) << "Make function body for inlining: " << SummarizeNode(node);
+
   // Finds a FunctionDef in a library and verifies that it exists.
   const auto find_fdef = [&flib_def, &node](
                              const string& name,
@@ -997,8 +968,7 @@
   // deprecated for a while, but we still support for compatibility reasons.
   if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
     NameAttrList func;
-    TF_RETURN_IF_ERROR(
-        GetNodeAttr(node.attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
+    TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), kFuncAttr, &func));
 
     const string grad = flib_def.FindGradient(func.name());
 
@@ -1029,7 +999,7 @@
           grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
 
     } else {
-      // Compute numerical gradient for a function by traversing its body.
+      // Build a gradient graph from the function body.
       const FunctionDef* fdef;
       TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
 
@@ -1047,24 +1017,112 @@
     TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
 
     VLOG(4) << "Instantiate a function call: function=" << func.name();
-    TF_RETURN_IF_ERROR(
-        FunctionDefToBodyHelper(*fdef, node.attrs(), &flib_def, fbody));
+    TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
+                                               &flib_def, fbody));
   }
 
   return Status::OK();
 }
 
+// Adds a control edge from each data input to the 'caller' to enforce strict
+// inputs semantics (all inputs are ready and alive). This is required when:
+//
+//  1) The function takes resources as inputs, and it doesn't have incoming
+//     control edges. In Tensorflow v2 context (eager mode) this should never
+//     happen, because automatic control dependencies tracking will add a
+//     control edge from the last op touching the resource. However such graphs
+//     might be produced by legacy v1 code without automatic dependency
+//     tracking. In this case strict function call semantics is required for
+//     enforcing side effects execution order.
+//
+//  2) One of the inputs is consuming Enter[is_constant=true] node, in which
+//     case it will be always alive, and potentially can lead to partial
+//     function execution after the last loop execution.
+//
+// Both of these cases would be considered illegal by construction in Tensorflow
+// V2, however we have to guarantee that graphs constructed with Tensorflow V1
+// will produce correct results.
+void AddStrictInputSemantics(Node* caller, Graph* g) {
+  const bool has_incoming_control_edges =
+      absl::c_any_of(caller->in_edges(),
+                     [](const Edge* edge) { return edge->IsControlEdge(); });
+
+  const bool has_resource_input =
+      absl::c_any_of(caller->input_types(),
+                     [](const DataType dtype) { return dtype == DT_RESOURCE; });
+
+  const bool has_constant_enter_input =
+      absl::c_any_of(caller->in_edges(), [](const Edge* edge) {
+        Node* src = edge->src();
+        return src->IsEnter() && CheckBoolAttr(src, "is_constant");
+      });
+
+  const bool requires_strict_semantics =
+      (!has_incoming_control_edges && has_resource_input) ||  // Case #1
+      (has_constant_enter_input);                             // Case #2
+  if (!requires_strict_semantics) return;
+
+  std::vector<const Node*> data_inputs;
+  data_inputs.reserve(caller->in_edges().size());
+
+  for (const Edge* edge : caller->in_edges()) {
+    if (edge->IsControlEdge()) continue;
+    data_inputs.push_back(edge->src());
+  }
+
+  VLOG(3) << "Add control edges from all data inputs to enforce strict "
+             "semantics with regard to function inputs";
+  for (const Node* node : data_inputs) {
+    g->AddControlEdge(g->FindNodeId(node->id()), caller);
+  }
+}
+
+// Adds a control edge from a frame node if the 'caller' is executing inside a
+// While loop (see control_flow.h for the 'frame' node explanation).
+void AddFrameForwardingControlEdge(const std::vector<ControlFlowInfo>& info,
+                                   Node* caller, Graph* g) {
+  // All nodes added to the graph by v2 control flow lowering and function
+  // inlining are guaranteed to have control edges to nested function calls.
+  if (caller->id() >= info.size()) return;
+
+  // Check if a lowered node is executing inside a while loop.
+  const Node* frame = info[caller->id()].frame;
+  const bool is_in_while_loop = frame->id() != Graph::kSourceId;
+  if (!is_in_while_loop) return;
+
+  // Check if a node already has an incoming control edge. All incoming edges
+  // must be from the same execution frame (executor.cc invariant), so if we
+  // already have an incoming control edge, it's guaranteed that it will "carry"
+  // the same frame as all regular inputs.
+  const bool has_incoming_control_edges =
+      absl::c_any_of(caller->in_edges(),
+                     [](const Edge* edge) { return edge->IsControlEdge(); });
+  if (has_incoming_control_edges) return;
+
+  VLOG(3) << "Add a frame forwarding control edge: from=" << frame->name()
+          << " to=" << caller->name();
+  g->AddControlEdge(g->FindNodeId(frame->id()), caller);
+}
+
+// Inlines all function calls that are safe for inlining into the main graph.
+// Also lowers control flow V2 ops (functional If/While) into the V1 low level
+// ops (Switch/Merge/...).
+//
+// Runs a placer after inlining, to keep all nodes in a graph placed.
 Status InlineFunctionCalls(const GrapplerItem& item,
-                           const FunctionLibraryDefinition& flib_def,
-                           const GraphDef& input_graph,
-                           std::unordered_set<string>* skip_nodes,
+                           const RewriterConfig::Toggle opt_level,
                            GraphDef* output_graph) {
-  VLOG(2) << "Inline function calls";
-  Graph graph(flib_def);
+  bool is_aggressive = opt_level == RewriterConfig::AGGRESSIVE;
+  VLOG(2) << "Inline function calls: grappler_item_id=" << item.id
+          << " (aggessive_mode=" << is_aggressive << ")";
+
+  FunctionLibraryDefinition flib_def =
+      FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library());
+  std::unique_ptr<Graph> graph = absl::make_unique<Graph>(flib_def);
 
   GraphConstructorOptions graph_constructor_options;
-  TF_RETURN_IF_ERROR(
-      ConvertGraphDefToGraph(graph_constructor_options, input_graph, &graph));
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_constructor_options,
+                                            item.graph, graph.get()));
 
   using NodeNames = absl::flat_hash_set<absl::string_view>;
   NodeNames fetch_nodes;
@@ -1074,20 +1132,42 @@
   }
   NodeNames keep_nodes(item.keep_ops.begin(), item.keep_ops.end());
 
+  std::vector<string> inlined_function_names;
+
+  // If a function call is inside a While loop, it must have an incoming control
+  // edge, because it will be used to pass execution frame into the function
+  // body. All nodes without inputs in the function body (e.g. Const and NoOp)
+  // will get an extra control edge from the 'input_control_node'.
+  std::vector<ControlFlowInfo> control_flow_info;
+  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &control_flow_info));
+
   // Function inlining always adds new nodes to the end of the list, so we keep
   // iterating until we are out of nodes.
-  for (int i = 2; i < graph.num_node_ids(); ++i) {
-    Node* n = graph.FindNodeId(i);
-
+  for (int i = 2; i < graph->num_node_ids(); ++i) {
+    Node* n = graph->FindNodeId(i);
     if (n == nullptr) continue;  // deleted node
-    if (MarkedForTpuCompilation(n)) continue;
-    if (MarkedForXlaCompilation(n)) continue;
+
+    // Special case for lowering functional control flow ops. We do not rely on
+    // LowerFunctionOpsPass because in Grappler we have to be more restrictive
+    // about what type of function calls we are allowed to inline.
+    if (LowerUsingSwitchMergeIsOn(n)) {
+      VLOG(2) << "Lower functional control flow op: " << SummarizeNode(*n);
+      AddStrictInputSemantics(n, graph.get());
+      AddFrameForwardingControlEdge(control_flow_info, n, graph.get());
+
+      if (n->type_string() == "If") {
+        TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), flib_def, false));
+      } else if (n->type_string() == "While") {
+        TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), flib_def, false));
+      }
+      continue;
+    }
 
     // Skip nodes that are not function calls.
     if (!IsFunctionCall(flib_def, *n)) continue;
-
-    // TODO(ezhulenev): Inline multi-device functions.
-    if (n->IsPartitionedCall()) continue;
+    // Skip function calls that we plan to compile later.
+    if (MarkedForTpuCompilation(n)) continue;
+    if (MarkedForXlaCompilation(n)) continue;
 
     // Function body that we will inline into the main graph. It can be a
     // function instantiation, or a gradient function instantiated from
@@ -1096,8 +1176,30 @@
     TF_RETURN_IF_ERROR(MakeFunctionBodyForInlining(*n, flib_def, &fbody));
 
     InlineFunctionBodyOptions inline_options;
-    inline_options.override_device = true;
-    inline_options.output_control_src = OutputControlSource::kDataOutputs;
+    // Ignore '_noinline' flag in aggressive mode.
+    inline_options.ignore_noinline = is_aggressive;
+
+    // Function calls created after inlining If/While ops are always inlined as
+    // multi-device functions and are not required to pass additional Grappler
+    // validations (side effects execution validation below).
+    bool force_inline_as_multi_device = LowerAsMultiDeviceFunctionIsOn(n);
+
+    // `PartitionedCall` is a TF-2.0 function call mechanism for multi-device
+    // functions:
+    // a) Function can be multi-device, and we can't override device placements.
+    // b) Automatic control dependencies tracking guarantees that all function
+    //    side-effectful nodes will have a path to one of the control outputs.
+    //    Control outputs and control edges between side-effectful (stateful)
+    //    nodes are used to explicitly mark the nodes that must execute, and to
+    //    define their execution order.
+    if (n->IsPartitionedCall() || force_inline_as_multi_device) {
+      inline_options.override_device = false;
+      inline_options.initialize_empty_device = true;
+      inline_options.output_control_src = OutputControlSource::kControlOutputs;
+    } else {
+      inline_options.override_device = true;
+      inline_options.output_control_src = OutputControlSource::kDataOutputs;
+    }
 
     if (fetch_nodes.contains(n->name())) {
       inline_options.keep_caller_node = KeepCallerNode::kFetchable;
@@ -1121,802 +1223,125 @@
       can_inline_function_call = ValidateSideEffectsExecution(
           *fbody, inline_options.output_control_src,
           has_outgoing_control_edges);
+
+      if (!can_inline_function_call.ok() &&
+          (is_aggressive || force_inline_as_multi_device)) {
+        VLOG(2) << "Ignore error: " << can_inline_function_call.error_message();
+        can_inline_function_call = Status::OK();
+      }
+    }
+    if (can_inline_function_call.ok()) {
+      can_inline_function_call = ValidateNoDeadOutputs(flib_def, *fbody);
     }
 
     if (can_inline_function_call.ok()) {
-      VLOG(2) << "Inline function call: " << SummarizeNode(*n);
-      TF_RETURN_IF_ERROR(InlineFunctionBody(graph.flib_def(), &graph, n,
+      VLOG(2) << "Inline function call node: " << n->name();
+      AddStrictInputSemantics(n, graph.get());
+      AddFrameForwardingControlEdge(control_flow_info, n, graph.get());
+
+      TF_RETURN_IF_ERROR(InlineFunctionBody(flib_def, graph.get(), n,
                                             fbody.get(), inline_options));
+      inlined_function_names.push_back(fbody->fdef.signature().name());
+
     } else {
       VLOG(2) << "Failed to inline function call node: "
-              << can_inline_function_call.error_message() << "; "
-              << SummarizeNode(*n);
+              << can_inline_function_call.error_message();
     }
   }
 
-  graph.ToGraphDef(output_graph);
-  return Status::OK();
-}
-
-// -------------------------------------------------------------------------- //
-// Inline indirect functions calls (aka PartitionedCallOp).
-//
-// When we inline indirect function calls, we instantiate the function body from
-// its FunctionDef and caller node attributes, and embed the instantiated graph
-// into the "main graph".
-//
-// In contrast to direct function calls, `PartitionedCallOp` has automatic
-// dependency tracking via input/output control edges, and we relax some of the
-// constraints that we have for direct function call inlining.
-//
-// Automatic control dependency rules:
-//
-// 1) "When a `PartitionedCallOp` function has a resource (DT_RESOURCE data
-//    type) input argument it "captures" the mutable resource.  This is
-//    implemented by automatically adding a incoming control edge from the
-//    previous side-effectful op touching that resource, and an outgoing control
-//    edge to the next side-effectful op using the same resource. This
-//    serializes the mutations of the resource to make graph execution
-//    deterministic.
-//
-// 2) All stateful ops inside a function body are guaranteed to execute in
-//    program order, this is achieved by adding control edges between stateful
-//    ops at graph construction time.
-//
-// 3) Furthermore, all ops accepting the same resource as an input are
-//    guaranteed to run in program order. This is also done by adding control
-//    edges at graph construction time. The last op touching the resource
-//    will have an outgoing control edge to all function return nodes, which
-//    will guarantee that all side effects to the resource will happen before
-//    function completion.
-//
-// Function call inlining must preserve side effect visibility:
-//
-// 1) All side effects to the captured resources, that happened before function
-//    call must be visible to the function body nodes using that resources.
-// 2) All side effects to the captured resources, that happened inside function
-//    body, must be visible to every op/function using that resource after the
-//    function call completed.
-//
-// To guarantee that these properties are preserved after inlining we:
-//
-// 1) Create "input_control" NoOp. Function call node incoming control edges
-//    will be forwarded *to* this node. Function inputs (Identity nodes) will
-//    have a control edge *from* this node. If function has no inputs, by
-//    construction it must have nodes without inputs in the function body, and
-//    in this case these nodes will have a control edge *from* this node.
-
-// 2) Create "output_control" NoOp. All nodes that have incoming control edge
-//    *from* the function call node, will be forwarded to this node. Function
-//    outputs (Identity nodes) will have a control edge *to* this node. This
-//    will guarantee that nodes that have control dependency on the function
-//    call, will observe all side-effects (guaranteed by graph construction with
-//    automatic control dependencies tracking).
-//
-// If after function instantiation we find a stateful or a dataset op inside
-// the function body, that is not reachable from any of the function outputs (or
-// if the function has no outputs), we do not inline it, because we can't
-// guarantee that these nodes will be executed in correct order (or executed at
-// all) after inlining.
-//
-// We do not try to add any extra control edges to make sure that all
-// side-effectful nodes will be executed, that should be handled at graph
-// construction time.
-
-struct MaybeDeadOutput {
-  const NodeDef* dead_tensor_src;
-  const NodeDef* output_node_dst;
-};
-
-// Finds all function outputs that might return a dead tensor. This can happen
-// if there is no `Merge` node on the path from the `Switch` node, to the
-// function output.
-Status MaybeDeadOutputs(const FunctionOptimizerContext& ctx,
-                        const GrapplerFunctionItem& item,
-                        std::vector<MaybeDeadOutput>* maybe_dead) {
-  VLOG(3) << "Find function outputs that might return dead tensors: item.id="
-          << item.id;
-  DCHECK(maybe_dead->empty()) << "Input argument must be an empty vector";
-
-  std::vector<const NodeDef*> dead_tensor_srcs;
-  for (const NodeDef& node : item.graph.node()) {
-    if (IsSwitch(node)) {
-      VLOG(4) << "Add dead tensors source. Switch node: " << node.name();
-      dead_tensor_srcs.push_back(&node);
-      continue;
-    }
-
-    // Regular (aka 'direct') function call can also produce dead tensors if
-    // the function body has mergeless switches.
-    const FunctionDef* func = ctx.function_library().Find(node.op());
-    if (func != nullptr) {
-      GrapplerFunctionItem func_item;
-      TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
-          *func, FunctionInstantiationAttributes(*func, node),
-          ctx.function_library(), ctx.graph_version(), &func_item));
-
-      std::vector<MaybeDeadOutput> func_dead_outputs;
-      TF_RETURN_IF_ERROR(MaybeDeadOutputs(ctx, func_item, &func_dead_outputs));
-
-      if (!func_dead_outputs.empty()) {
-        VLOG(4) << "Add dead tensors source. Function call: " << node.op()
-                << " node=" << node.name();
-        dead_tensor_srcs.push_back(&node);
-      }
-    }
-  }
-
-  // If we do not have dead tensor sources in the function body, it's
-  // guaranteed that all output tensors can't become dead.
-  if (dead_tensor_srcs.empty()) return Status::OK();
-
-  // Names of the function body nodes that return function output values.
-  absl::flat_hash_set<absl::string_view> output_nodes;
-  for (const auto& output_arg : item.outputs()) {
-    output_nodes.insert(output_arg.node_name);
-  }
-
-  GraphTopologyView topology_view;
-  TF_RETURN_IF_ERROR(topology_view.InitializeFromGraph(item.graph));
-
-  for (const NodeDef* dead_tensor_src : dead_tensor_srcs) {
-    DfsTraversal(topology_view, {dead_tensor_src},
-                 TraversalDirection::kFollowOutputs,
-                 // Stop traversal when reached first `Merge` node.
-                 DfsPredicates::Advance(
-                     [](const NodeDef* node) { return !IsMerge(*node); }),
-                 // If we reached output node, add MaybeDeadOutput edge.
-                 DfsCallbacks::PreOrder([&](const NodeDef* node) {
-                   if (output_nodes.find(node->name()) != output_nodes.end()) {
-                     maybe_dead->push_back({dead_tensor_src, node});
-                   }
-                 }));
-  }
-
-  return Status::OK();
-}
-
-// Returns `Status::OK()` iff `node` is an indirect function call of `func`, and
-// we know how to inline it into the main graph, otherwise returns and error
-// indicating why the function call is not inlinable.
-Status IsInlinableIndirectFunctionCall(const FunctionOptimizerContext& ctx,
-                                       const FunctionDef& func,
-                                       const NodeDef& func_node) {
-  // We inline direct function calls above, using different rules.
-  if (!IsIndirectFunctionCall(func, func_node)) {
-    return errors::InvalidArgument("Unsupported function call type: ",
-                                   SummarizeNodeDef(func_node));
-  }
-
-  if (MarkedNoInline(func)) {
-    return errors::FailedPrecondition(
-        "Can't inline function marked with '_noinline': ",
-        SummarizeNodeDef(func_node));
-  }
-
-  // Function specialization and inlining must be mutually exclusive.
-  if (MarkedSpecialized(func)) {
-    return errors::FailedPrecondition(
-        "Can't inline function created in Grappler function specialization: ",
-        SummarizeNodeDef(func_node));
-  }
-
-  // We can't inline functions that are in a fetch set, because it would
-  // invalidate fetch tensors (function call node fully inlined and doesn't
-  // exist in the optimized graph).
-  if (ctx.IsFetchNode(func_node.name())) {
-    return errors::FailedPrecondition(
-        "Can't inline function in a Grappler item fetch set: ",
-        SummarizeNodeDef(func_node));
-  }
-
-  return Status::OK();
-}
-
-// Checks that all side-effects will be executed in well defined order. We do it
-// by checking if there is a path from stateful/dataset ops to one of the
-// control output nodes.
-Status CheckThatSideEffectsWillExecute(
-    const FunctionOptimizerContext& ctx,
-    const GraphTopologyView& graph_topo_view,
-    const absl::flat_hash_set<string> control_output_nodes) {
-  // In aggressive mode we just print a warning for side-effectful nodes that
-  // might not be executed after inlining.
-  const bool aggressive = ctx.opt_level() == RewriterConfig::AGGRESSIVE;
-
-  for (const NodeDef& func_body_node : graph_topo_view.graph()->node()) {
-    const bool node_must_execute =
-        IsDataset(func_body_node) ||
-        IsStateful(func_body_node, &ctx.function_library());
-
-    // If op has DT_RESOURCE argument it will be marked as stateful, though if
-    // it only reads from that resource, it's allowed to prune it, because it
-    // can't produce any visible side-effects.
-    const bool read_only = IsReadVariableOp(func_body_node);
-
-    // _Retval marked as stateful, but we will remove it before inlining.
-    const bool retval = IsRetval(func_body_node);
-
-    if (read_only || retval || !node_must_execute) continue;
-
-    VLOG(3) << "Check that node " << func_body_node.name()
-            << " will execute after inlining.";
-    bool will_execute = false;
-
-    // Check if we reached one of the output nodes.
-    const auto callbacks = DfsCallbacks::PreOrder([&](const NodeDef* node) {
-      if (control_output_nodes.contains(node->name())) {
-        VLOG(4) << "Found a path to control output node: " << node->name();
-        will_execute = true;
-      }
-    });
-
-    // Stop if we already proved that node will execute.
-    const auto predicates = DfsPredicates::Enter(
-        [&](const NodeDef* node) { return !will_execute; });
-
-    DfsTraversal(graph_topo_view, {&func_body_node},
-                 TraversalDirection::kFollowOutputs, predicates, callbacks);
-
-    if (!will_execute) {
-      const string error_message = absl::StrCat(
-          "Can't guarantee execution of a side-effectful node, that is not "
-          "reachable from function outputs. Function body node: ",
-          SummarizeNodeDef(func_body_node));
-
-      if (aggressive) {
-        LOG(WARNING) << error_message;
-      } else {
-        return errors::Internal(error_message);
-      }
-    }
-  }
-
-  return Status::OK();
-}
-
-Status PlaceInlinedFunctionBody(
-    const NodeDef& func_node, const GrapplerFunctionItem& item,
-    const absl::flat_hash_map<absl::string_view, int>& input_args_idx,
-    FunctionOptimizerContext* ctx, GraphDef* placed_graph_def) {
-  // Control flow lowering and Placer works with a Graph object.
-  std::unique_ptr<Graph> func_body_graph =
-      absl::make_unique<Graph>(ctx->function_library());
-
-  GraphConstructorOptions opts;
-  TF_RETURN_IF_ERROR(
-      ConvertGraphDefToGraph(opts, item.graph, func_body_graph.get()));
+  VLOG(4) << "Inlined " << inlined_function_names.size()
+          << " function calls: " << absl::StrJoin(inlined_function_names, ", ");
 
   // ------------------------------------------------------------------------ //
   // Grappler receives the graph after PRE_PLACEMENT, Placer, and POST_PLACEMENT
-  // passes, so each node has a valid device assignment. Also V2 control
-  // flow ops (functional If and While) should have been lowered to V1 control
-  // flow (Switch and Merge nodes). To keep the graph valid for execution we
-  // must assign device to every inlined graph node, and also lower the control
-  // flow.
+  // passes, so each node has a valid device assignment. After function inlining
+  // and control flow V2 lowering we have to keep graph placed.
 
-  GraphOptimizationPassOptions opt_options;
-  opt_options.graph = &func_body_graph;
-  opt_options.flib_def = ctx->mutable_function_library();
+  if (inlined_function_names.empty()) {
+    VLOG(3) << "Not placing graph after function inlining"
+            << " (did not inline any of the function calls).";
 
-  // TODO(ezhulenev): Should we run full PRE_PLACEMENT pass here? And
-  // POST_PLACEMENT after placer?
-  LowerFunctionalOpsPass pass(/*lower_function_calls=*/false,
-                              /*keep_lowered_nodes_fetchable=*/false);
-  TF_RETURN_IF_ERROR(pass.Run(opt_options));
+  } else if (item.devices().empty()) {
+    // If there are no devices available for placer, we do not place graph after
+    // function inlining. This happens when Grappler is optimizing the function
+    // library, or when a graph is optimized "offline", without an active runtime
+    // session, for example as a part of batch job for graph
+    // analysis/optimization. GrapplerItem instantiated from a function library
+    // doesn't have to be fully placed after all optimizations; it will be
+    // placed by the function library runtime before execution.
+    VLOG(3) << "Not placing graph after function inlining"
+            << " (device set is empty)";
 
-  // ------------------------------------------------------------------------ //
-  // Before placing the function body nodes we pin input arguments to the
-  // same device as their corresponding input nodes.
-
-  for (Node* func_body_node : func_body_graph->nodes()) {
-    const auto input_arg_idx = input_args_idx.find(func_body_node->name());
-
-    if (input_arg_idx != input_args_idx.end()) {
-      const int input_idx = input_arg_idx->second;
-      const GraphView::OutputPort output_port =
-          ctx->graph_view().GetRegularFanin({&func_node, input_idx});
-
-      const string& input_device = output_port.node->device();
-
-      if (!input_device.empty()) {
-        VLOG(3) << "Pin inlined function input node '" << func_body_node->name()
-                << "' to the '" << output_port.node->device() << "' device.";
-        func_body_node->set_requested_device(output_port.node->device());
-      } else {
-        VLOG(3) << "Inlined function input node '" << func_body_node->name()
-                << "' device is undefined.";
-      }
-    }
-  }
-
-  // ------------------------------------------------------------------------ //
-  // After placing nodes corresponding to the function inputs, we need to assign
-  // device placements to all other function body nodes.
-
-  const DeviceSet* devices = ctx->devices();
-
-  if (devices->devices().empty()) {
-    // If there are no devices available for placer, we do not place function
-    // body nodes. This happens when Grappler optimizing function library, or
-    // when graph optimized "offline", without active runtime session, for
-    // example as a part of batch job for graph analysis/optimization.
-    // GrapplerItem instantiated from a function library doesn't have to be
-    // fully placed after all optimization, it will be placed by the function
-    // library runtime before execution.
-    VLOG(3) << "Do not place instantiated function body.";
   } else {
     // If we are running in an active runtime session, Grappler will get the
     // graph after initial placing is done, and we should have devices for the
     // placer.
-    VLOG(3) << "Run placer for instantiated function body. Devices: ["
-            << absl::StrJoin(
-                   devices->devices(), ", ",
-                   [](string* out, const Device* d) { out->append(d->name()); })
-            << "]";
+    VLOG(3) << "Run placer for the graph after function inlining. "
+            << "Devices: [" << absl::StrJoin(item.devices(), ", ") << "]";
 
-    // Use function caller node device as a default for placer.
-    const Device* default_device =
-        devices->FindDeviceByName(func_node.device());
+    DeviceSet device_set;                               // does not own devices
+    std::vector<std::unique_ptr<Device>> fake_devices;  // owns fake devices
 
-    Placer placer(func_body_graph.get(), item.id, devices, default_device);
+    for (const string& name : item.devices()) {
+      auto device = absl::make_unique<FakeDevice>(name);
+      device_set.AddDevice(device.get());
+      fake_devices.push_back(std::move(device));
+    }
+
+    Placer placer(graph.get(), item.id, &device_set);
     TF_RETURN_IF_ERROR(placer.Run());
   }
 
-  // Convert Graph back to the placed GraphDef.
-  func_body_graph->ToGraphDef(placed_graph_def);
-
+  graph->ToGraphDef(output_graph);
   return Status::OK();
 }
 
-Status InlineIndirectFunctionCall(const NodeDef& func_node,
-                                  const FunctionDef& func,
-                                  FunctionOptimizerContext* ctx,
-                                  GraphDef* optimized_graph) {
-  VLOG(2) << "Inline indirect function call: " << SummarizeNodeDef(func_node);
-  VLOG(4) << "Inlined function definition: " << DebugString(func);
-  TF_RETURN_IF_ERROR(IsInlinableIndirectFunctionCall(*ctx, func, func_node));
+// Restores tensor mapping after function specialization: all inputs must be
+// connected to valid nodes.
+void RestoreTensorMapping(const FunctionOptimizerContext& ctx,
+                          GraphDef* optimized_graph) {
+  if (ctx.tensor_mapping().empty()) return;
 
-  const AttrSlice func_instantiation_attr =
-      FunctionInstantiationAttributes(func, func_node);
-
-  GrapplerFunctionItem item;
-  Status item_status = MakeGrapplerFunctionItem(func, func_instantiation_attr,
-                                                ctx->function_library(),
-                                                ctx->graph_version(), &item);
-
-  if (!item_status.ok()) {
-    return errors::InvalidArgument("Failed to inline function ", func_node.op(),
-                                   " instantiated by ", func_node.name(),
-                                   ". Error: ", item_status.error_message());
-  }
-
-  // `PartitionedCallOp` invokes functions with `allow_dead_tensors = true` to
-  // reset dead flag, and return default initialized tensors instead of a dead
-  // tensors. There is no way to express this in a regular Tensorflow graph, so
-  // we choose not to inline if a function can have dead tensors as an output
-  // position. In practice `mergeless switches` should not exists in a function
-  // body, because tf-eager will only use v2 control flow ops.
-  std::vector<MaybeDeadOutput> maybe_dead_outputs;
-  TF_RETURN_IF_ERROR(MaybeDeadOutputs(*ctx, item, &maybe_dead_outputs));
-  if (!maybe_dead_outputs.empty()) {
-    struct MaybeDeadOutputFormatter {
-      void operator()(string* out, const MaybeDeadOutput& md) const {
-        absl::StrAppend(out, SummarizeNodeDef(*md.dead_tensor_src));
-      }
-    };
-    return errors::FailedPrecondition(
-        "Can't inline function with dead outputs. Dead tensor sources (size = ",
-        maybe_dead_outputs.size(), "): ",
-        absl::StrJoin(maybe_dead_outputs, "\n", MaybeDeadOutputFormatter()));
-  }
-
-  GraphView::InputPort control_input_port =
-      ctx->graph_view().GetInputPort(func_node.name(), Graph::kControlSlot);
-  GraphView::OutputPort control_output_port =
-      ctx->graph_view().GetOutputPort(func_node.name(), Graph::kControlSlot);
-
-  // Nodes that have side effects to the captured resources.
-  std::vector<string> happens_before;
-  absl::c_transform(
-      ctx->graph_view().GetFanin(control_input_port),
-      std::back_inserter(happens_before),
-      [](const GraphView::OutputPort port) { return port.node->name(); });
-
-  VLOG(3) << "Happens before set (size = " << happens_before.size()
-          << "): " << absl::StrJoin(happens_before, ", ");
-
-  // Nodes that must observe side effects to the captured resources.
-  std::vector<string> happens_after;
-  absl::c_transform(
-      ctx->graph_view().GetFanout(control_output_port),
-      std::back_inserter(happens_after),
-      [](const GraphView::InputPort port) { return port.node->name(); });
-
-  VLOG(3) << "Happens after set (size = " << happens_after.size()
-          << "): " << absl::StrJoin(happens_after, ", ");
-
-  // Regular (data) inputs to the function call.
-  std::vector<SafeTensorId> inputs;
-  for (const string& input : func_node.input()) {
-    SafeTensorId tensor_id = ParseTensorName(input);
-    if (tensor_id.index() == Graph::kControlSlot) break;
-    inputs.push_back(tensor_id);
-  }
-
-  // Mapping from input argument node to function input position.
-  absl::flat_hash_map<absl::string_view, int> input_args_idx;
-  for (const InputArgInstantiation& input_arg : item.inputs()) {
-    const int idx = input_args_idx.size();
-    input_args_idx[input_arg.node_name] = idx;
-  }
-
-  const string prefix = strings::StrCat(func_node.name(), "/");
-
-  // ------------------------------------------------------------------------ //
-  // IMPORTANT: Actual inputs will be added to the following nodes at the very
-  // last stage, because we don't want to have invalid edges in a function body
-  // graph (control edges that depend on the nodes in the "outer" optimized
-  // graph).
-
-  // If one of the function inputs is a dead tensor, we must not execute any of
-  // the function body nodes, and let the dead tensor flag propagate through the
-  // inlined function body. We add NoOp inputs_ready node, and add control edges
-  // to it from all input nodes. Inlined function arguments (Identity nodes)
-  // will have a control dependency on it.
+  // During function specialization, we might prune unused function outputs. We
+  // need to "close the holes" that might appear in the function outputs.
   //
-  // TODO(ezhulenev): We do not need to provide this guarantee for ALL nodes in
-  // the function body. We must only ensure that we do not generate observable
-  // side effects.
+  // Example: prune unused output "f:1"
   //
-  // If the function call node has incoming control edges, we will update them
-  // to use this node as destination, to ensure side-effects execution order.
-  NodeDef* inputs_ready_node = nullptr;
-  if (func_node.input_size() > 0) {
-    inputs_ready_node = item.graph.add_node();
-    inputs_ready_node->set_op("NoOp");
-    inputs_ready_node->set_name(kInputsReadyNodeName);
-  }
-
-  // All nodes that have a control edge from the function call node, will be
-  // updated to have a control edge from 'side_effects_executed_node`. This node
-  // will have control edges from all function control outputs (see
-  // `control_ret` in FunctionDef). This a "barrier" that guarantees that all
-  // ops with side effects in the function body were executed
+  //   f = my_func[T=float](...)          f = my_func_specialized[T=float](...)
+  //   a = Identity(f:0)             ->   a = Identity(f:0)
+  //   b = Identity(f:2)                  b = Identity(f:1)
   //
-  // If the function call node has no outgoing control edges, it means that no
-  // one is interested in the function side-effect affecting captured resources.
-  //
-  // If node is in keep_ops set, it means that it must execute. This could
-  // happen if the graph is an instantiation of a function with control output.
-  NodeDef* side_effects_executed_node = nullptr;
-  if (!happens_after.empty() || ctx->IsKeepOp(func_node.name())) {
-    side_effects_executed_node = item.graph.add_node();
-    side_effects_executed_node->set_op("NoOp");
-    side_effects_executed_node->set_name(kSideEffectsExecutedNodeName);
-  }
+  // Tensor mapping (size=1): [f:2 -> f:1]
+  for (NodeDef& node : *optimized_graph->mutable_node()) {
+    for (int idx = 0; idx < node.input_size(); ++idx) {
+      TensorId input_tensor = ParseTensorName(node.input(idx));
+      if (input_tensor.index() == Graph::kControlSlot) break;
 
-  // If function executed only for the regular data outputs, it's totally safe
-  // to prune side-effects. If side-effects order is important, it must be
-  // captured at graph construction time via control edges.
-  if (item.control_output_size() > 0 && happens_after.empty()) {
-    VLOG(2) << "Function has control outputs and empty happens after set.";
-  }
-
-  // ------------------------------------------------------------------------ //
-  // If we have a node inside the function body without inputs (e.g. Const), we
-  // must attach a control dependency to it, to make sure that if a function
-  // call happens inside a loop, the node will be evaluated in correct frame.
-  //
-  // If the function call node has no inputs and no control dependencies, it
-  // means that it can't be a function call inside a loop, and we can safely
-  // insert that node without inputs into the main graph.
-  //
-  // TODO(ezhulenev): Use FrameMap (see grappler/utils/frame.h) to find out if
-  // the function is called inside a loop.
-  std::vector<string> empty_inputs_hook;
-  if (inputs_ready_node != nullptr) {
-    empty_inputs_hook.push_back(inputs_ready_node->name());
-  }
-
-  // ------------------------------------------------------------------------ //
-  // Grappler called after PRE_PLACEMENT and PLACEMENT passes, so we have to
-  // make sure that after inlining all nodes will have valid device assignment.
-
-  GraphDef placed_graph_def;
-  TF_RETURN_IF_ERROR(PlaceInlinedFunctionBody(func_node, item, input_args_idx,
-                                              ctx, &placed_graph_def));
-
-  // ------------------------------------------------------------------------ //
-  // Mapping from the '_Retval' node name to the output tensor. We build this
-  // mapping after the placement, because we might have inlined some of the
-  // functional If/While nodes (see a call to LowerFunctionalOpsPass).
-  absl::flat_hash_map<string, string> output_tensors;
-
-  for (const NodeDef& func_body_node : placed_graph_def.node()) {
-    if (!IsRetval(func_body_node)) continue;
-    if (func_body_node.input_size() != 1) {
-      return errors::Internal("_Retval node must have single input: ",
-                              SummarizeNodeDef(func_body_node));
-    }
-    output_tensors.emplace(func_body_node.name(), func_body_node.input(0));
-  }
-
-  // ------------------------------------------------------------------------ //
-  // After all nodes placed we need to prepare them for inlining into the
-  // optimized graph: turn placeholders into identities, update nodes
-  // connectivity, etc...
-
-  const auto inlined_node_name = [&func_node](const string& name) -> string {
-    return AddPrefixToNodeName(name, /*prefix=*/func_node.name());
-  };
-
-  for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
-    const string& node_name = func_body_node.name();
-
-    // Turn _Arg nodes added in place of input arguments into identity nodes.
-    const auto input_arg_idx = input_args_idx.find(node_name);
-    if (input_arg_idx != input_args_idx.end()) {
-      DCHECK_EQ(0, func_body_node.input_size());
-      func_body_node.set_op("Identity");
-      func_body_node.mutable_attr()->erase("index");
-      func_body_node.mutable_attr()->erase("shape");
-      const int input_idx = input_arg_idx->second;
-      func_body_node.add_input(inputs[input_idx].ToString());
-
-      // Add a control dependency on 'inputs_ready' node, to guarantee that all
-      // inputs are alive and all side-effects executed before function body.
-      if (inputs_ready_node) {
-        func_body_node.add_input(
-            AsControlDependency(inlined_node_name(inputs_ready_node->name())));
-      }
-    } else {
-      // Update inputs of the regular function body nodes.
-      for (string& input : *func_body_node.mutable_input()) {
-        input = inlined_node_name(input);
-      }
-
-      // Check if we need to ensure node execution in correct loop frame.
-      bool node_needs_empty_inputs_hook =
-          // We have a node to hook and node has no inputs.
-          !empty_inputs_hook.empty() && func_body_node.input_size() == 0 &&
-          // Inputs ready node will always have edge from main graph. If
-          // function call has no regular and control inputs, we will not add
-          // inputs_ready node to the function body graph.
-          node_name != kInputsReadyNodeName &&
-          // The node acting as a return barrier for execution of side effects
-          // might not have any inputs (in case function has no control outputs,
-          // but we still added it because of non-empty happens-after set), so
-          // we must make sure it's executed in correct frame.
-          (node_name != kSideEffectsExecutedNodeName ||
-           item.control_output_size() == 0);
-
-      if (node_needs_empty_inputs_hook) {
-        *func_body_node.add_input() =
-            AsControlDependency(inlined_node_name(empty_inputs_hook[0]));
-      }
-    }
-
-    // Add the function node name as a prefix 1) to node name to avoid
-    // collisions; 2) to frame name to avoid multiple LoopCond nodes in one
-    // frame after inlining.
-    TF_RETURN_IF_ERROR(
-        AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &func_body_node));
-
-    // After inlining into the optimized graph, NodeDef must have all attributes
-    // defined, which is not required for a node in a FunctionDef.
-    const OpDef* op_def;
-    TF_RETURN_IF_ERROR(
-        ctx->function_library().LookUpOpDef(func_body_node.op(), &op_def));
-    AddDefaultsToNodeDef(*op_def, &func_body_node);
-  }
-
-  // ------------------------------------------------------------------------ //
-  // Check that after inlining all side-effects will be executed in well defined
-  // order. We do it by checking if there is a path from stateful/dataset ops to
-  // one of the control output nodes.
-
-  // Names of the inlined control output nodes.
-  absl::flat_hash_set<string> inlined_control_output_nodes;
-  for (const ControlOutput& control_output : item.control_outputs()) {
-    inlined_control_output_nodes.insert(
-        inlined_node_name(control_output.node_name));
-  }
-
-  // Construct a graph topology view for DFS traversals (skip invalid edges for
-  // input nodes connected to nodes in the optimized graph).
-  GraphTopologyView placed_topo_view(/*skip_invalid_edges=*/true);
-  TF_RETURN_IF_ERROR(placed_topo_view.InitializeFromGraph(placed_graph_def));
-  TF_RETURN_IF_ERROR(CheckThatSideEffectsWillExecute(
-      *ctx, placed_topo_view, inlined_control_output_nodes));
-
-  // ------------------------------------------------------------------------ //
-  // Move all the nodes to the optimized graph after successful preprocessing.
-
-  if (inputs_ready_node != nullptr) {
-    string inlined_node = inlined_node_name(inputs_ready_node->name());
-    absl::optional<int> node_idx = placed_topo_view.GetNodeIndex(inlined_node);
-
-    absl::flat_hash_set<string> input_nodes;
-    for (const string& input : func_node.input()) {
-      SafeTensorId tensor = ParseTensorName(input);
-
-      // Input node might have been a function call that was already inlined.
-      auto it = ctx->tensor_mapping().find(tensor);
-      while (it != ctx->tensor_mapping().end()) {
-        tensor = it->second;
-        it = ctx->tensor_mapping().find(tensor);
-      }
-
-      if (input_nodes.insert(tensor.node()).second) {
-        placed_graph_def.mutable_node(*node_idx)->add_input(
-            AsControlDependency(tensor.node()));
+      auto mapping = ctx.tensor_mapping().find(input_tensor);
+      if (mapping != ctx.tensor_mapping().end()) {
+        node.set_input(idx, mapping->second.ToString());
       }
     }
   }
-
-  if (side_effects_executed_node != nullptr) {
-    string inlined_node = inlined_node_name(side_effects_executed_node->name());
-    absl::optional<int> node_idx = placed_topo_view.GetNodeIndex(inlined_node);
-
-    // Add control edges from all control output nodes.
-    for (const string& node_name : inlined_control_output_nodes) {
-      placed_graph_def.mutable_node(*node_idx)->add_input(
-          AsControlDependency(node_name));
-    }
-
-    // Forward all control dependencies in the optimized graph to the new node.
-    ctx->AddControlOverrides(func_node, {inlined_node});
-  }
-
-  for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
-    // We bypass _Retval nodes and fetch tensors from `retval.input(0)`.
-    if (IsRetval(func_body_node)) continue;
-    optimized_graph->add_node()->Swap(&func_body_node);
-  }
-
-  // Indirect function call is fully inlined into the optimized graph, and we do
-  // not copy the original function call node, so we have to setup tensor
-  // mapping from old output tensors, to the outputs of inlined nodes.
-  int output_idx = 0;
-  for (const OutputArgInstantiation& output : item.outputs()) {
-    const string& output_tensor = output_tensors.at(output.node_name);
-
-    const SafeTensorId from_tensor(func_node.name(), output_idx++);
-    const SafeTensorId to_tensor = ParseTensorName(output_tensor);
-
-    const SafeTensorId inlined_to_tensor =
-        SafeTensorId(absl::StrCat(func_node.name(), "/", to_tensor.node()),
-                     to_tensor.index());
-
-    ctx->AddTensorMapping(from_tensor, inlined_to_tensor);
-  }
-
-  // If function call node was in keep_ops set, it means that we need to keep a
-  // node with the same name in the optimized graph. We forward all data
-  // consumers to inlined nodes, and we verify that the node is not in a fetch
-  // set, so it's safe to assume that the function call node is only required
-  // for a control edge source.
-  if (ctx->IsKeepOp(func_node.name())) {
-    VLOG(4) << "Add NoOp for inlined function in keep ops set.";
-    NodeDef* keep_func_node = optimized_graph->add_node();
-    keep_func_node->set_op("NoOp");
-    keep_func_node->set_name(func_node.name());
-    keep_func_node->set_device(func_node.device());
-    keep_func_node->add_input(
-        AsControlDependency(inlined_node_name(kSideEffectsExecutedNodeName)));
-  }
-
-  VLOG(3) << "Successfully inlined indirect function call: "
-          << SummarizeNodeDef(func_node);
-
-  return Status::OK();
-}
-
-// Restores graph invariants after function specialization and inlining: all
-// inputs must be connected to valid nodes.
-Status RestoreGraphInvariants(const FunctionOptimizerContext& ctx,
-                              GraphDef* optimized_graph) {
-  // After function specialization and inlining graph might be in invalid
-  // state, and some nodes can read tensors that do not exists anymore in the
-  // optimized graph: function call node was fully inlined into the graph, or
-  // output index was invalidated by the output pruning.
-
-  if (!ctx.tensor_mapping().empty()) {
-    for (NodeDef& node : *optimized_graph->mutable_node()) {
-      for (int idx = 0; idx < node.input_size(); ++idx) {
-        TensorId input_tensor = ParseTensorName(node.input(idx));
-        if (input_tensor.index() == Graph::kControlSlot) break;
-
-        auto mapping = ctx.tensor_mapping().find(input_tensor);
-        if (mapping != ctx.tensor_mapping().end()) {
-          node.set_input(idx, mapping->second.ToString());
-        }
-      }
-    }
-  }
-
-  // Function inlining instantiates function body directly into the optimized
-  // graph, and we might end up with control dependencies to the nodes that no
-  // longer exist in a graph. We need to apply control overrides to all
-  // invalidated nodes, and rewire control dependencies to the control outputs
-  // node (it's also possible to rewrite singe control edge into multiple edges
-  // to inlined side-effectful nodes).
-
-  if (!ctx.control_overrides().empty()) {
-    for (NodeDef& node : *optimized_graph->mutable_node()) {
-      // Keep track of new control inputs to the node.
-      absl::flat_hash_set<string> add_ctrl_inputs;
-
-      // Remove all invalidated control inputs.
-      for (int idx = 0; idx < node.input_size(); /* see below */) {
-        // TODO(ezhulenev): Use non-allocating TensorId after migrating
-        // `control_overrides()` to absl::flat_hash_set.
-        SafeTensorId input_tensor = ParseTensorName(node.input(idx));
-
-        auto overrides = ctx.control_overrides().find(input_tensor.node());
-        if (overrides != ctx.control_overrides().end()) {
-          // If this happens it's a bug in the function inlining.
-          if (input_tensor.index() != Graph::kControlSlot) {
-            return errors::Internal(
-                "Illegal input edge from inlined function call node");
-          }
-          // Remove control dependency to the inlined function call node.
-          node.mutable_input()->SwapElements(idx, node.input_size() - 1);
-          node.mutable_input()->RemoveLast();
-
-          // Keep track of all overrides.
-          for (const string& override : overrides->second) {
-            add_ctrl_inputs.insert(AsControlDependency(override));
-          }
-        } else {
-          // Go to the next input only if the current one was not invalidated,
-          // otherwise we need to check the swapped input as well.
-          ++idx;
-        }
-      }
-
-      // Add overrides to the node inputs.
-      for (const string& ctrl_input : add_ctrl_inputs) {
-        node.add_input(ctrl_input);
-      }
-    }
-  }
-
-  return Status::OK();
 }
 
 }  // namespace
 
 Status FunctionOptimizer::RunFunctionOptimizerPass(
-    const GrapplerItem& item, const GraphDef& graph, const int iteration,
-    std::unordered_set<string>* skip_nodes, GraphDef* optimized_graph,
-    bool* graph_has_unoptimized_function_calls) const {
-  VLOG(3) << absl::Substitute(
-      "Run function optimizer pass (iteration = $0): grappler_item_id = $1",
-      iteration, item.id);
+    const GrapplerItem& item, GraphDef* optimized_graph) const {
+  VLOG(3) << "Run function optimizer pass: grappler_item_id=" << item.id;
 
   // Inline all function calls into a graph using common_runtime/function
   // implementation (see `InlineFunctionBody` function documentation).
   GraphDef graph_after_inlining;
-  TF_RETURN_IF_ERROR(InlineFunctionCalls(
-      item, FunctionLibraryDefinition(OpRegistry::Global(), graph.library()),
-      graph, skip_nodes, &graph_after_inlining));
+  TF_RETURN_IF_ERROR(
+      InlineFunctionCalls(item, opt_level_, &graph_after_inlining));
 
+  // Specialize function calls that we could not inline.
   FunctionOptimizerContext ctx(item, opt_level_, graph_after_inlining);
 
-  bool inline_gradients = options_.enable_symbolic_gradient_inlining;
-  bool inline_func = options_.enable_function_inlining;
-  bool specialize_func = options_.enable_function_specialization;
-
-  // We will process all the nodes in topological order, to correctly handle
-  // inlining of function call chains.
-  std::vector<const NodeDef*> topo_ordered_nodes;
-  TF_RETURN_IF_ERROR(
-      ComputeTopologicalOrder(graph_after_inlining, &topo_ordered_nodes));
-
-  for (const NodeDef* node : topo_ordered_nodes) {
-    // Each node optimization can modify optimized graph only by adding new
+  for (const NodeDef& node : graph_after_inlining.node()) {
+    // Function specialization can modify optimized graph only by adding new
     // nodes, we can check node size to make sure that graph was not modified.
     const int num_nodes_before = optimized_graph->node_size();
     const auto is_graph_modified = [&]() {
@@ -1925,116 +1350,50 @@
       return num_nodes > num_nodes_before;
     };
 
-    // Copy node from the `graph` to the `optimized_graph`.
-    const auto copy_node = [&]() { *optimized_graph->add_node() = *node; };
+    // Copy node from the `graph_after_inlining` to the `optimized_graph`.
+    const auto copy_node = [&]() { *optimized_graph->add_node() = node; };
 
-    // If we already failed to optimize this node during one of the previous
-    // passes, we just give up, and do not try on more time.
-    if (skip_nodes->find(node->name()) != skip_nodes->end()) {
-      VLOG(3) << "Skip optimization for node: " << node->name();
+    // Find if a node is a function call (direct or indirect).
+    const FunctionDef* func = FindFunctionCall(ctx, node);
+    if (func == nullptr) {
       copy_node();
       continue;
     }
 
-// Skip errors if optimized graph was not modified before error happened.
-#define TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(...)                     \
-  do {                                                             \
-    const Status _status = (__VA_ARGS__);                          \
-    if (TF_PREDICT_FALSE(!_status.ok() && is_graph_modified()))    \
-      return _status;                                              \
-    if (TF_PREDICT_FALSE(!_status.ok() && !is_graph_modified())) { \
-      VLOG(3) << "Skip error: " << _status.error_message();        \
-      skip_nodes->insert(node->name());                            \
-      copy_node();                                                 \
-    }                                                              \
-  } while (0)
+    const string& func_name = func->signature().name();
 
-    // ---------------------------------------------------------------------- //
-    // Inline or specialize function calls.                                //
-    // ---------------------------------------------------------------------- //
+    // Specialize it to its instantiation context if it has something worth
+    // specializing.
+    bool specialization_worthy = IsParametrized(*func) ||
+                                 HasTrulyConstInputs(node, ctx) ||
+                                 HasUnusedOutputs(node, *func, ctx);
+    // Do not specialize if function has custom gradient.
+    const string grad_func = ctx.function_library().FindGradient(func_name);
 
-    // Find if a node is a function call (direct or indirect).
-    const FunctionDef* func = FindFunctionCall(ctx, *node);
-
-    if (func != nullptr) {
-      const string& func_name = func->signature().name();
-
-      const bool is_indirect_func = IsIndirectFunctionCall(*func, *node);
-
-      // Inline indirect function call if it's inlinable.
-      if (inline_func && is_indirect_func) {
-        Status inlinable = IsInlinableIndirectFunctionCall(ctx, *func, *node);
-        if (inlinable.ok()) {
-          TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
-              InlineIndirectFunctionCall(*node, *func, &ctx, optimized_graph));
-          continue;
-        } else {
-          VLOG(2) << inlinable.error_message();
-          skip_nodes->insert(node->name());
-        }
+    if (grad_func.empty() && specialization_worthy) {
+      // TODO(ezhulenev): Specialize function call if input has a known shape.
+      // Specialize function body for its instantiation attributes and inputs.
+      Status status = SpecializeFunction(node, *func, &ctx, optimized_graph);
+      if (!status.ok() && is_graph_modified()) {
+        return status;
+      } else if (!status.ok() && !is_graph_modified()) {
+        VLOG(3) << "Skip specialization error: " << status.error_message();
+        copy_node();
       }
-
-      // Specialize it to its instantiation context if can't be inlined,
-      // and it has something worth specializing.
-      bool specialization_worthy = IsParametrized(*func) ||
-                                   HasTrulyConstInputs(*node, ctx) ||
-                                   HasUnusedOutputs(*node, *func, ctx);
-
-      // Do not specialize if function has custom gradient.
-      const string grad_func = ctx.function_library().FindGradient(func_name);
-
-      if (specialize_func && grad_func.empty() && specialization_worthy) {
-        // TODO(ezhulenev): Specialize function call if input has a known shape.
-        // Specialize function body for its instantiation attributes and inputs.
-        TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
-            SpecializeFunction(*node, *func, &ctx, optimized_graph));
-        continue;
-      } else {
-        VLOG(2) << "Skip function specialization: " << func->signature().name();
-        skip_nodes->insert(node->name());
-      }
+      continue;
+    } else {
+      VLOG(2) << "Skip function specialization: " << func->signature().name();
+      copy_node();
     }
-
-    // ---------------------------------------------------------------------- //
-    // If we reached this point, node was not handled by any of the stages
-    // (inline, specialize), simply copy the node to the optimized graph.
-    copy_node();
-
-#undef TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED
   }
 
-  TF_RETURN_IF_ERROR(RestoreGraphInvariants(ctx, optimized_graph));
+  RestoreTensorMapping(ctx, optimized_graph);
 
   // Preserve the graph version.
-  *optimized_graph->mutable_versions() = graph.versions();
-
+  *optimized_graph->mutable_versions() = item.graph.versions();
   // Prune unreachable function from the library.
-  if (options_.enable_trim_function_library) {
-    *optimized_graph->mutable_library() =
-        PruneFunctionLibrary(ctx.function_library(), *optimized_graph);
-  } else {
-    *optimized_graph->mutable_library() = ctx.function_library().ToProto();
-  }
-
-  // Before returning we check if after single optimization pass we have more
-  // unoptimized function calls.
-  *graph_has_unoptimized_function_calls = false;
-  for (const NodeDef& node : optimized_graph->node()) {
-    // Check if we can inline symbolic gradient.
-    if (IsSymbolicGradient(node) && inline_gradients &&
-        skip_nodes->count(node.name()) == 0) {
-      *graph_has_unoptimized_function_calls = true;
-      break;
-    }
-
-    // Check if after inlining we have unoptimized function calls.
-    const FunctionDef* func = FindFunctionCall(ctx, node);
-    if (func != nullptr && !MarkedSpecialized(*func) &&
-        skip_nodes->count(node.name()) == 0) {
-      *graph_has_unoptimized_function_calls = true;
-      break;
-    }
-  }
+  *optimized_graph->mutable_library() =
+      PruneFunctionLibrary(ctx.function_library(), *optimized_graph);
 
   return Status::OK();
 }
@@ -2047,35 +1406,7 @@
     return Status::OK();
   }
 
-  // Do not retry failed function inlining or specialization.
-  std::unordered_set<string> skip_nodes;
-  bool graph_has_unoptimized_function_calls = false;
-
-  // We'll keep running function optimizer pass until we inlined and optimized
-  // all function call nodes.
-  int iteration = 0;
-  constexpr int kMaxIterations = 3;
-
-  // 1. Run first optimizer pass with GrapplerItem.graph.
-  TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(
-      item, item.graph, 0, &skip_nodes, optimized_graph,
-      &graph_has_unoptimized_function_calls));
-
-  // 2. If after function inlining we have unoptimized function calls, we have
-  // to run function optimization pass one more time.
-  while (graph_has_unoptimized_function_calls) {
-    if (iteration++ > kMaxIterations) {
-      VLOG(1) << "Break function optimizer loop at iteration #" << iteration;
-      break;
-    }
-
-    GraphDef workspace_graph;
-    workspace_graph.Swap(optimized_graph);
-
-    TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(
-        item, workspace_graph, iteration, &skip_nodes, optimized_graph,
-        &graph_has_unoptimized_function_calls));
-  }
+  TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(item, optimized_graph));
 
   return Status::OK();
 }
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.h b/tensorflow/core/grappler/optimizers/function_optimizer.h
index ab90281..8c96bbc 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.h
@@ -41,25 +41,15 @@
  private:
   friend class FunctionOptimizerTest;
 
-  struct FunctionOptimizerOptions {
-    bool enable_function_inlining = true;
-    bool enable_function_specialization = true;
-    bool enable_symbolic_gradient_inlining = true;
-    bool enable_trim_function_library = true;
-  };
-
   // Runs a single function optimizer pass over the `graph`. All nodes that are
   // not function calls will be copied from the `graph` to the
   // `optimized_graph`. Function call nodes inlined or specialized, and
   // instantiated function body or specialized function call nodes will be added
   // to the `optimized_graph`.
-  Status RunFunctionOptimizerPass(
-      const GrapplerItem& item, const GraphDef& graph, const int iteration,
-      std::unordered_set<string>* skip_nodes, GraphDef* optimized_graph,
-      bool* graph_has_unoptimized_function_calls) const;
+  Status RunFunctionOptimizerPass(const GrapplerItem& item,
+                                  GraphDef* optimized_graph) const;
 
   RewriterConfig::Toggle opt_level_;
-  FunctionOptimizerOptions options_;
 };
 
 }  // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
index 828827a..1455399 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc
@@ -33,12 +33,7 @@
 constexpr char kDevice[] = "/job:localhost/replica:0/task:0/device:CPU:0";
 }  // namespace
 
-class FunctionOptimizerTest : public GrapplerTest {
- protected:
-  void DisableFunctionSpecialization(FunctionOptimizer* optimizer) {
-    optimizer->options_.enable_function_specialization = false;
-  }
-};
+class FunctionOptimizerTest : public GrapplerTest {};
 
 TEST_F(FunctionOptimizerTest, InlineFunction_SimpleFunction) {
   using test::function::NDef;
@@ -257,7 +252,6 @@
   using test::function::NDef;
 
   FunctionOptimizer optimizer(RewriterConfig::DEFAULT);
-  DisableFunctionSpecialization(&optimizer);  // do not specialize noinline func
 
   const Tensor kTwo = test::AsScalar<int64>(2);
   FunctionDef func = FunctionDefHelper::Define(
@@ -513,6 +507,10 @@
   item.feed.emplace_back("a", pi);
   item.feed.emplace_back("b", pi);
 
+  const string input_x = "Func/c/input/_0";
+  const string input_y = "Func/c/input/_1";
+  const string output_z = "Func/c/output/_2";
+
   // If device set is empty, inlined function body must not be placed.
   {
     GraphDef optimized_graph;
@@ -524,14 +522,14 @@
 
          // Function body nodes are not placed, however function input nodes
          // must copy device assignment from input arguments.
-         NDef("c/inputs_ready", "NoOp", {"^a", "^b"}, {}),
-         NDef("c/x", "Identity", {"a:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
-              kDevice),
-         NDef("c/y", "Identity", {"b:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
-              kDevice),
-         NDef("c/mul", "Mul", {"c/x", "c/y"}, {{"T", DT_FLOAT}}),
+         NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
+         NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, kDevice),
+         // TODO(ezhulenev): Currently inlined function body "implicitly placed"
+         // with a 'inline_options.initialize_empty_device' flag.
+         NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, kDevice),
+         NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, kDevice),
 
-         NDef("d", "Identity", {"c/mul:0"}, {{"T", DT_FLOAT}}, kDevice)},
+         NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, kDevice)},
         // Function library.
         {mul_func});
 
@@ -555,14 +553,12 @@
         {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
          NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
 
-         NDef("c/inputs_ready", "NoOp", {"^a", "^b"}, {}, kDevice),
-         NDef("c/x", "Identity", {"a:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
-              kDevice),
-         NDef("c/y", "Identity", {"b:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
-              kDevice),
-         NDef("c/mul", "Mul", {"c/x", "c/y"}, {{"T", DT_FLOAT}}, kDevice),
+         NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
+         NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, kDevice),
+         NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, kDevice),
+         NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, kDevice),
 
-         NDef("d", "Identity", {"c/mul:0"}, {{"T", DT_FLOAT}}, kDevice)},
+         NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, kDevice)},
         // Function library.
         {mul_func});
 
@@ -650,54 +646,68 @@
        NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
 
        // Initialize variable with one of the placeholders.
-       NDef("v", "VarHandleOp", {}, {{"dtype", DT_FLOAT}, {"shape", scalar}}),
+       NDef("v", "VarHandleOp", {}, {{"dtype", DT_FLOAT}, {"shape", scalar}},
+            kDevice),
        NDef("init_v", "AssignVariableOp", {"v", "a"}, {{"dtype", DT_FLOAT}},
             kDevice),
 
        // Function body of a first function call inlined into the graph.
-       NDef("f1/inputs_ready", "NoOp", {"^a", "^b", "^v", "^init_v"}, {},
+       NDef("Func/f1/input_control_node/_0", "NoOp", {"^init_v"}, {}, kDevice),
+
+       NDef("Func/f1/input/_1", "Identity",  // input: 'x'
+            {"a", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
+            kDevice),
+       NDef("Func/f1/input/_2", "Identity",  // input: 'y'
+            {"b", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
+            kDevice),
+       NDef("Func/f1/input/_3", "Identity",  // input: 'v'
+            {"v", "^Func/f1/input_control_node/_0"}, {{"T", DT_RESOURCE}},
             kDevice),
 
-       NDef("f1/x", "Identity", {"a:0", "^f1/inputs_ready"}, {{"T", DT_FLOAT}},
-            kDevice),
-       NDef("f1/y", "Identity", {"b:0", "^f1/inputs_ready"}, {{"T", DT_FLOAT}},
-            kDevice),
-       NDef("f1/v", "Identity", {"v:0", "^f1/inputs_ready"},
-            {{"T", DT_RESOURCE}}, kDevice),
-
-       NDef("f1/one", "Const", {"^f1/inputs_ready"},
+       NDef("f1/one", "Const", {"^Func/f1/input_control_node/_0"},
             {{"dtype", DT_FLOAT}, {"value", kOne}}, kDevice),
-       NDef("f1/add", "AssignAddVariableOp", {"f1/v", "f1/one"},
+       NDef("f1/mul", "Mul", {"Func/f1/input/_1", "Func/f1/input/_2"},
+            {{"T", DT_FLOAT}}, kDevice),
+       NDef("f1/add", "AssignAddVariableOp", {"Func/f1/input/_3", "f1/one"},
             {{"dtype", DT_FLOAT}}, kDevice),
-       NDef("f1/mul", "Mul", {"f1/x", "f1/y"}, {{"T", DT_FLOAT}}, kDevice),
 
-       NDef("f1/side_effects_executed", "NoOp", {"^f1/add"}, {}, kDevice),
+       NDef("Func/f1/output/_4", "Identity", {"f1/mul"}, {{"T", DT_FLOAT}},
+            kDevice),
+       NDef("Func/f1/output_control_node/_5", "NoOp", {"^f1/add"}, {}, kDevice),
 
        // Function body of a second function call also inlined into the graph,
-       // and input nodes read directly from the inlined nodes of the first
-       // function call.
-       NDef("f2/inputs_ready", "NoOp",
-            {"^v", "^f1/mul", "^f1/side_effects_executed"}, {}, kDevice),
+       // and input nodes read from the output nodes of the first function call.
+       NDef("Func/f2/input_control_node/_6", "NoOp",
+            {"^Func/f1/output_control_node/_5"}, {}, kDevice),
 
-       NDef("f2/x", "Identity", {"f1/mul:0", "^f2/inputs_ready"},
+       NDef("Func/f2/input/_7", "Identity",  // input: 'x'
+            {"Func/f1/output/_4", "^Func/f2/input_control_node/_6"},
             {{"T", DT_FLOAT}}, kDevice),
-       NDef("f2/y", "Identity", {"f1/mul:0", "^f2/inputs_ready"},
+       NDef("Func/f2/input/_8", "Identity",  // input: 'y'
+            {"Func/f1/output/_4", "^Func/f2/input_control_node/_6"},
             {{"T", DT_FLOAT}}, kDevice),
-       NDef("f2/v", "Identity", {"v:0", "^f2/inputs_ready"},
-            {{"T", DT_RESOURCE}}, kDevice),
+       NDef("Func/f2/input/_9", "Identity",  // input: 'v'
+            {"v", "^Func/f2/input_control_node/_6"}, {{"T", DT_RESOURCE}},
+            kDevice),
 
-       NDef("f2/one", "Const", {"^f2/inputs_ready"},
+       NDef("f2/one", "Const", {"^Func/f2/input_control_node/_6"},
             {{"dtype", DT_FLOAT}, {"value", kOne}}, kDevice),
-       NDef("f2/add", "AssignAddVariableOp", {"f2/v", "f2/one"},
+       NDef("f2/add", "AssignAddVariableOp", {"Func/f2/input/_9", "f2/one"},
             {{"dtype", DT_FLOAT}}, kDevice),
-       NDef("f2/mul", "Mul", {"f2/x", "f2/y"}, {{"T", DT_FLOAT}}, kDevice),
+       NDef("f2/mul", "Mul", {"Func/f2/input/_7", "Func/f2/input/_8"},
+            {{"T", DT_FLOAT}}, kDevice),
 
-       NDef("f2/side_effects_executed", "NoOp", {"^f2/add"}, {}, kDevice),
+       NDef("Func/f2/output/_10", "Identity", {"f2/mul"}, {{"T", DT_FLOAT}},
+            kDevice),
+       NDef("Func/f2/output_control_node/_11", "NoOp", {"^f2/add"}, {},
+            kDevice),
 
-       // Return values read directly from inlined nodes.
-       NDef("out_1", "Identity", {"f2/mul:0"}, {{"T", DT_FLOAT}}, kDevice),
+       // Return values read from inlined output nodes.
+       NDef("out_1", "Identity", {"Func/f2/output/_10"}, {{"T", DT_FLOAT}},
+            kDevice),
        NDef("out_2", "ReadVariableOp",
-            {"v", "^f1/side_effects_executed", "^f2/side_effects_executed"},
+            {"v", "^Func/f1/output_control_node/_5",
+             "^Func/f2/output_control_node/_11"},
             {{"dtype", DT_FLOAT}}, kDevice)},
 
       // Function library.
@@ -757,20 +767,22 @@
   GraphDef optimized_graph;
   TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));
 
+  const string input_x = "Func/c/input/_0";
+  const string input_y = "Func/c/input/_1";
+  const string output_z = "Func/c/output/_2";
+
   GraphDef expected = test::function::GDef(
       {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
        NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
 
        // Function must be inlined and `mul` node placed on a requested device,
        // and input `Identity` nodes must be colocated with their source nodes.
-       NDef("c/inputs_ready", "NoOp", {"^a", "^b"}, {}, cpu0),
-       NDef("c/x", "Identity", {"a:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
-            cpu0),
-       NDef("c/y", "Identity", {"b:0", "^c/inputs_ready"}, {{"T", DT_FLOAT}},
-            cpu1),
-       NDef("c/mul", "Mul", {"c/x", "c/y"}, {{"T", DT_FLOAT}}, cpu1),
+       NDef(input_x, "Identity", {"a"}, {{"T", DT_FLOAT}}, cpu0),
+       NDef(input_y, "Identity", {"b"}, {{"T", DT_FLOAT}}, cpu1),
+       NDef("c/mul", "Mul", {input_x, input_y}, {{"T", DT_FLOAT}}, cpu1),
+       NDef(output_z, "Identity", {"c/mul"}, {{"T", DT_FLOAT}}, cpu1),
 
-       NDef("d", "Identity", {"c/mul:0"}, {{"T", DT_FLOAT}}, cpu0)},
+       NDef("d", "Identity", {output_z}, {{"T", DT_FLOAT}}, cpu0)},
       // Function library.
       {mul_func});
 
@@ -809,8 +821,10 @@
       {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
        NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
 
+       NDef("c", "NoOp", {}, {}, kDevice),
+
        // Call function first time.
-       NDef("f1", "PartitionedCall", {"a", "b"},
+       NDef("f1", "PartitionedCall", {"a", "b", "^c"},
             {{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
              {"Tout", DataTypeSlice{DT_FLOAT}},
              {"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
@@ -836,31 +850,49 @@
       {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
        NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
 
+       NDef("c", "NoOp", {}, {}, kDevice),
+
        // Function body of a first function call inlined into the graph.
-       NDef("f1/inputs_ready", "NoOp", {"^a", "^b"}, {}, kDevice),
-       NDef("f1/x", "Identity", {"a:0", "^f1/inputs_ready"}, {{"T", DT_FLOAT}},
+       NDef("Func/f1/input_control_node/_0", "NoOp", {"^c"}, {}, kDevice),
+
+       NDef("Func/f1/input/_1", "Identity",  // input: 'x'
+            {"a", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
             kDevice),
-       NDef("f1/y", "Identity", {"b:0", "^f1/inputs_ready"}, {{"T", DT_FLOAT}},
+       NDef("Func/f1/input/_2", "Identity",  // input: 'y'
+            {"b", "^Func/f1/input_control_node/_0"}, {{"T", DT_FLOAT}},
             kDevice),
-       NDef("f1/mul", "Mul", {"f1/x", "f1/y"}, {{"T", DT_FLOAT}}, kDevice),
-       // Control input from `inputs_ready` node is added to ensure correct
-       // frame execution.
-       NDef("f1/side_effects_executed", "NoOp", {"^f1/inputs_ready"}, {},
+
+       NDef("f1/mul", "Mul", {"Func/f1/input/_1", "Func/f1/input/_2"},
+            {{"T", DT_FLOAT}}, kDevice),
+
+       NDef("Func/f1/output/_3", "Identity", {"f1/mul"}, {{"T", DT_FLOAT}},
             kDevice),
+       // Control input from `input_control_node` node is added to ensure
+       // correct frame execution.
+       NDef("Func/f1/output_control_node/_4", "NoOp",
+            {"^Func/f1/input_control_node/_0"}, {}, kDevice),
 
        // Function body of a second function call also inlined into the graph,
-       // and input nodes read directly from the inlined nodes of the first
+       // and input nodes read directly from the output nodes of the first
        // function call, and control dependency edge removed.
-       NDef("f2/inputs_ready", "NoOp", {"^f1/mul", "^f1/side_effects_executed"},
-            {}, kDevice),
-       NDef("f2/x", "Identity", {"f1/mul:0", "^f2/inputs_ready"},
-            {{"T", DT_FLOAT}}, kDevice),
-       NDef("f2/y", "Identity", {"f1/mul:0", "^f2/inputs_ready"},
-            {{"T", DT_FLOAT}}, kDevice),
-       NDef("f2/mul", "Mul", {"f2/x", "f2/y"}, {{"T", DT_FLOAT}}, kDevice),
+       NDef("Func/f2/input_control_node/_5", "NoOp",
+            {"^Func/f1/output_control_node/_4"}, {}, kDevice),
 
-       // Return directly from inlined node of f2.
-       NDef("out", "Identity", {"f2/mul:0"}, {{"T", DT_FLOAT}}, kDevice)},
+       NDef("Func/f2/input/_6", "Identity",
+            {"Func/f1/output/_3", "^Func/f2/input_control_node/_5"},
+            {{"T", DT_FLOAT}}, kDevice),
+       NDef("Func/f2/input/_7", "Identity",
+            {"Func/f1/output/_3", "^Func/f2/input_control_node/_5"},
+            {{"T", DT_FLOAT}}, kDevice),
+
+       NDef("f2/mul", "Mul", {"Func/f2/input/_6", "Func/f2/input/_7"},
+            {{"T", DT_FLOAT}}, kDevice),
+       NDef("Func/f2/output/_8", "Identity", {"f2/mul"}, {{"T", DT_FLOAT}},
+            kDevice),
+
+       // Return directly from output node of f2.
+       NDef("out", "Identity", {"Func/f2/output/_8"}, {{"T", DT_FLOAT}},
+            kDevice)},
 
       // Function library.
       {mul_func});
@@ -1003,22 +1035,24 @@
        NDef("b", "Placeholder", {}, {{"dtype", DT_BOOL}}, kDevice),
 
        // Function body of a first function call inlined into the graph.
-       NDef("fn/inputs_ready", "NoOp", {"^a", "^b"}, {}, kDevice),
-       NDef("fn/x", "Identity", {"a:0", "^fn/inputs_ready"}, {{"T", DT_FLOAT}},
-            kDevice),
-       NDef("fn/cond", "Identity", {"b:0", "^fn/inputs_ready"},
-            {{"T", DT_BOOL}}, kDevice),
-       NDef("fn/switch", "Switch", {"fn/x:0", "fn/cond:0"}, {{"T", DT_FLOAT}},
-            kDevice),
-       NDef("fn/if_false", "Identity", {"fn/switch:0"}, {{"T", DT_FLOAT}},
+       NDef("Func/fn/input/_0", "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
+       NDef("Func/fn/input/_1", "Identity", {"b"}, {{"T", DT_BOOL}}, kDevice),
+
+       NDef("fn/switch", "Switch", {"Func/fn/input/_0", "Func/fn/input/_1"},
+            {{"T", DT_FLOAT}}, kDevice),
+       NDef("fn/if_false", "Identity", {"fn/switch"}, {{"T", DT_FLOAT}},
             kDevice),
        NDef("fn/if_true", "Identity", {"fn/switch:1"}, {{"T", DT_FLOAT}},
             kDevice),
-       NDef("fn/merge", "Merge", {"fn/if_false:0", "fn/if_true:0"},
+       NDef("fn/merge", "Merge", {"fn/if_false", "fn/if_true"},
             {{"T", DT_FLOAT}, {"N", 2}}, kDevice),
 
-       // Return directly from inlined node.
-       NDef("out", "Identity", {"fn/merge:0"}, {{"T", DT_FLOAT}}, kDevice)},
+       NDef("Func/fn/output/_2", "Identity", {"fn/merge"}, {{"T", DT_FLOAT}},
+            kDevice),
+
+       // Return directly from inlined function output node.
+       NDef("out", "Identity", {"Func/fn/output/_2"}, {{"T", DT_FLOAT}},
+            kDevice)},
 
       // Function library.
       {no_dead_outputs});
@@ -1088,22 +1122,25 @@
       {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
 
        // Inlined inputs of `b` node.
-       NDef("b/inputs_ready", "NoOp", {"^a"}, {}, kDevice),
-       NDef("b/x", "Identity", {"a:0", "^b/inputs_ready"}, {{"T", DT_FLOAT}},
-            kDevice),
+       NDef("Func/b/input/_0", "Identity", {"a"}, {{"T", DT_FLOAT}}, kDevice),
 
        // Inlined inputs of `square` node inside inlined `MySquare` function.
-       NDef("b/square/inputs_ready", "NoOp", {"^b/x"}, {}, kDevice),
-       NDef("b/square/x", "Identity", {"b/x:0", "^b/square/inputs_ready"},
+       NDef("Func/b/square/input/_2", "Identity", {"Func/b/input/_0"},
             {{"T", DT_FLOAT}}, kDevice),
-       NDef("b/square/y", "Identity", {"b/x:0", "^b/square/inputs_ready"},
+       NDef("Func/b/square/input/_3", "Identity", {"Func/b/input/_0"},
             {{"T", DT_FLOAT}}, kDevice),
 
        // Inlined mul node from the `MyMul` function.
-       NDef("b/square/mul", "Mul", {"b/square/x", "b/square/y"},
+       NDef("b/square/mul", "Mul",
+            {"Func/b/square/input/_2", "Func/b/square/input/_3"},
             {{"T", DT_FLOAT}}, kDevice),
 
-       NDef("c", "Identity", {"b/square/mul:0"}, {{"T", DT_FLOAT}}, kDevice)},
+       NDef("Func/b/square/output/_4", "Identity", {"b/square/mul"},
+            {{"T", DT_FLOAT}}, kDevice),
+       NDef("Func/b/output/_1", "Identity", {"Func/b/square/output/_4"},
+            {{"T", DT_FLOAT}}, kDevice),
+
+       NDef("c", "Identity", {"Func/b/output/_1"}, {{"T", DT_FLOAT}}, kDevice)},
       // Function library.
       {mul_func});
 
@@ -1155,7 +1192,7 @@
            }},
       },
       /* Mapping between function returns and function node outputs. */
-      {{"z", "if_node:output:0"}});
+      {{"z", "if_node:output:0"}}, {{"side_effect", "if_node"}});
 
   // Build a computation graph for:
   //   is_add: bool
@@ -1179,7 +1216,7 @@
              {"f", FDH::FunctionRef("AddOrMul")}},
             kDevice),
 
-       NDef("d", "Identity", {"c"}, {{"T", DT_FLOAT}}, kDevice)},
+       NDef("d", "Identity", {"c", "^c"}, {{"T", DT_FLOAT}}, kDevice)},
       // Function library.
       {add_or_mul_func, add_func, mul_func});
 
@@ -1888,10 +1925,10 @@
 TEST_F(FunctionOptimizerTest, PruningUselessLibraryFunctions) {
   using test::function::NDef;
   FunctionOptimizer optimizer(RewriterConfig::DEFAULT);
-  DisableFunctionSpecialization(&optimizer);
   auto func = test::function::XTimesTwo();
   (*func.mutable_attr())["_noinline"].set_b(true);
   GrapplerItem item;
+  item.id = "test_graph";
   item.graph = test::function::GDef(
       {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, "/device:CPU:0"),
        NDef("y", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, "/device:CPU:0"),
@@ -1906,8 +1943,9 @@
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
 
-  EXPECT_EQ(output.library().function().size(), 1);
-  EXPECT_EQ(output.library().function(0).signature().name(), "XTimesTwo");
+  ASSERT_EQ(output.library().function().size(), 1);
+  EXPECT_EQ(output.library().function(0).signature().name(),
+            "XTimesTwo_specialized_for_y_at_test_graph");
 }
 
 }  // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 6e68a53..e07fe78 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -15,6 +15,7 @@
 
 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
 
+#include "absl/strings/str_join.h"
 #include "absl/strings/substitute.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/function.pb.h"
@@ -775,6 +776,8 @@
     Status added_device = item.AddDevice(d->name());
     if (!added_device.ok()) VLOG(3) << added_device.error_message();
   }
+  VLOG(3) << "Grappler available devices: "
+          << absl::StrJoin(item.devices(), ", ");
 
   // Add fetches so that the graph can be pruned.
   item.fetch.swap(ret_node_names);