Make cond_v2 If op lowering work in a defun + eager.

Prior to this change, the lowering pass assumed that the If op
functions would be available in the If op's graph. If the If op is
defined in a defun and then called via eager execution, the functions
will be in the eager context, but not in the defun's graph. This
change makes the lowering pass correctly use the function library
passed in by the caller via GraphOptimizationPassOptions.

PiperOrigin-RevId: 215271990
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index dfce7c2..a02084f 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -38,11 +38,12 @@
  public:
   enum Branch { kElseBranch = 0, kThenBranch = 1 };
 
-  // Create a CondBuilder to create the lowering of If op.  that has then and
+  // Create a CondBuilder to create the lowered form of `if_op` with then and
   // else functions named `then_fn_name` and `else_fn_name` respectively in the
-  // given graph.
+  // `graph`. The functions should be available in `flib`.
   CondBuilder(Node* if_op, const string& then_fn_name,
-              const string& else_fn_name, Graph* graph);
+              const string& else_fn_name, const FunctionLibraryDefinition& flib,
+              Graph* graph);
 
   // Constructs the basic conditional control flow using switch and merge nodes.
   Status CreatePivotNodes();
@@ -89,6 +90,7 @@
   Node* then_call_node_;
   Node* else_call_node_;
   Graph* graph_;
+  const FunctionLibraryDefinition& flib_;
   string name_;
 
   NodeBuilder then_call_builder_;
@@ -96,9 +98,11 @@
 };
 
 CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
-                         const string& else_fn_name, Graph* graph)
+                         const string& else_fn_name,
+                         const FunctionLibraryDefinition& flib, Graph* graph)
     : if_op_(if_op),
       graph_(graph),
+      flib_(flib),
       name_(if_op->name()),
       then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()),
       else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) {
@@ -193,15 +197,15 @@
   return Status::OK();
 }
 
-Status InlineCallInGraph(Node* n, Graph* g) {
-  const auto& lib = g->flib_def();
-  const FunctionDef* fdef = lib.Find(n->type_string());
+Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib,
+                         Graph* g) {
+  const FunctionDef* fdef = flib.Find(n->type_string());
   CHECK(fdef != nullptr);
   FunctionBody* fbody;
   TF_RETURN_IF_ERROR(
-      FunctionDefToBodyHelper(*fdef, n->attrs(), &lib,
-                              [&lib](const string& op, const OpDef** sig) {
-                                return lib.LookUpOpDef(op, sig);
+      FunctionDefToBodyHelper(*fdef, n->attrs(), &flib,
+                              [&flib](const string& op, const OpDef** sig) {
+                                return flib.LookUpOpDef(op, sig);
                               },
                               &fbody));
   // TODO(jpienaar): Improve this interface to make the need to delete it
@@ -219,8 +223,8 @@
 }
 
 Status CondBuilder::InlineCallNodes() {
-  TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, graph_));
-  TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, graph_));
+  TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, flib_, graph_));
+  TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, flib_, graph_));
   return Status::OK();
 }
 
@@ -240,6 +244,12 @@
     return errors::Internal("Lowering If op requires a graph to be available.");
   }
 
+  FunctionLibraryDefinition* flib = options.flib_def;
+  if (flib == nullptr) {
+    return errors::Internal(
+        "Lowering If op requires a FunctionLibraryDefinition to be available.");
+  }
+
   // Match all the nodes that need to be rewritten.
   gtl::InlinedVector<Node*, 2> matches;
   for (Node* n : g->op_nodes()) {
@@ -251,12 +261,14 @@
     }
   }
   for (Node* n : matches) {
-    TF_RETURN_IF_ERROR(RewriteNode(n, g));
+    TF_RETURN_IF_ERROR(RewriteNode(n, *flib, g));
   }
   return Status::OK();
 }
 
-Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) {
+Status LowerIfOpPass::RewriteNode(Node* n,
+                                  const FunctionLibraryDefinition& flib,
+                                  Graph* g) {
   const AttrValue* then_attr = n->attrs().Find("then_branch");
   if (then_attr == nullptr) {
     return errors::InvalidArgument("Then branch function missing");
@@ -266,7 +278,8 @@
     return errors::InvalidArgument("Else branch function missing");
   }
 
-  CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), g);
+  CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), flib,
+                 g);
   TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
   TF_RETURN_IF_ERROR(cb.AddInputs());
   TF_RETURN_IF_ERROR(cb.AddOutputs());
diff --git a/tensorflow/core/common_runtime/lower_if_op.h b/tensorflow/core/common_runtime/lower_if_op.h
index a9ef39a..5ab1123 100644
--- a/tensorflow/core/common_runtime/lower_if_op.h
+++ b/tensorflow/core/common_runtime/lower_if_op.h
@@ -29,8 +29,9 @@
   Status Run(const GraphOptimizationPassOptions& options) override;
 
  private:
-  // Rewrite the given If node `n` in graph `g` to use the switch-merge form.
-  Status RewriteNode(Node* n, Graph* g);
+  // Rewrite the given If node `n` in graph `g` to use the switch-merge
+  // form. `flib` should contain the branch functions referenced by `n`.
+  Status RewriteNode(Node* n, const FunctionLibraryDefinition& flib, Graph* g);
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc
index 319a617..044a355 100644
--- a/tensorflow/core/common_runtime/lower_if_op_test.cc
+++ b/tensorflow/core/common_runtime/lower_if_op_test.cc
@@ -36,9 +36,7 @@
 namespace {
 
 Status Rewrite(std::unique_ptr<Graph>* graph) {
-  FunctionDefLibrary flib;
-  FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
-
+  FunctionLibraryDefinition flib_def((*graph)->flib_def());
   GraphOptimizationPassOptions opt_options;
   opt_options.graph = graph;
   opt_options.flib_def = &flib_def;
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index d91a848..ae61be6 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -31,6 +31,7 @@
 from tensorflow.python.client import device_lib
 from tensorflow.python.client import session
 from tensorflow.python.eager import context
+from tensorflow.python.eager import function as eager_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
@@ -3414,6 +3415,27 @@
       self.assertAllEqual(r.numpy(), 10)
       self.assertFalse(isinstance(r, list))
 
+  def testCondInDefun(self):
+    if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
+      self.skipTest("b/113346829 (gpu failure)")
+
+    with context.eager_mode():
+
+      @eager_function.defun
+      def foo(pred):
+        # TODO(b/111124878): this only needs to output one element.
+        fn1 = lambda: (constant_op.constant(10), constant_op.constant(100))
+        fn2 = lambda: (constant_op.constant(20), constant_op.constant(200))
+        return control_flow_ops.cond(constant_op.constant(pred), fn1, fn2)
+
+      r = foo(True)
+      self.assertAllEqual(r[0].numpy(), 10)
+      self.assertNotIsInstance(r, list)
+
+      r = foo(False)
+      self.assertAllEqual(r[0].numpy(), 20)
+      self.assertNotIsInstance(r, list)
+
   def testWhileLoop(self):
     with context.eager_mode():
       tensor = constant_op.constant([1, 2, 3, 4, 5])