[TF:XLA] Always return a tuple-shaped result when converting a TensorFlow graph into an XLA computation. Previously we had special-case logic that skipped the tuple when the number of outputs was 1, but that meant any code reasoning about the outputs needed a matching special case. Using tuples unconditionally lets us simplify the code.
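
For example, the caller-side special case this removes (lifted from the
xla_local_launch_op.cc change below):

  // Before: callers had to branch on whether the result was a tuple.
  gpu::DeviceMemoryBase buffer;
  if (output_is_tuple) {
    buffer = output->buffer({output_num});
  } else {
    CHECK_EQ(0, output_num);
    buffer = output->buffer({});
  }

  // After: results are always tuple-shaped, so indexing is unconditional.
  gpu::DeviceMemoryBase buffer = output->buffer({output_num});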

Since XLA requires a loop condition to return a scalar PRED rather than a tuple, wrap each loop-condition computation in a small computation that unpacks the output tuple and returns its PRED element.
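
A condensed sketch of that wrapper (mirroring the while_op.cc change below;
error handling elided):

  std::unique_ptr<xla::ComputationBuilder> cb =
      builder->CreateSubBuilder("cond_wrapper");
  auto inputs = cb->Parameter(0, cond_input_shape, "inputs");
  auto outputs = cb->Call(*cond.computation, {inputs});
  cb->GetTupleElement(outputs, 0);
  xla::Computation cond_wrapper = cb->Build().ConsumeValueOrDie();

  // The wrapper, not the raw condition, is what While() consumes.
  xla::ComputationDataHandle while_result =
      builder->While(cond_wrapper, *body.computation, init);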

[XLA:CPU] Use the CallInliner pass on CPU to work around a wrong-output bug exposed by this change. Add a unit test that exhibits the bug when CallInliner is disabled.
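
The test routes a While condition through Call, which produced wrong output
on CPU before inlining (condensed from the new while_test.cc case below;
condition_callee returns a one-element tuple holding the PRED):

  ComputationBuilder builder(client_, "condition");
  auto prev = builder.Parameter(0, result_shape, "prev");
  auto result = builder.Call(condition_callee, {prev});
  builder.GetTupleElement(result, 0);
  Computation condition = builder.Build().ConsumeValueOrDie();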

PiperOrigin-RevId: 169333679
diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
index 7e760d7..5cbff88 100644
--- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
@@ -259,7 +259,6 @@
   XlaLocalRuntimeContext local_runtime_context;
 
   std::unique_ptr<xla::ShapedBuffer> output;
-  bool output_is_tuple;
   if (!kernel->computation->IsNull()) {
     // Build xla::ShapedBuffers that point directly to the Tensor buffers.
     std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
@@ -326,7 +325,6 @@
     if (VLOG_IS_ON(2)) {
       VLOG(2) << "Result tuple shape: " << output->shape().DebugString();
     }
-    output_is_tuple = xla::ShapeUtil::IsTuple(output->shape());
   }
   CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
 
@@ -356,13 +354,7 @@
       const TensorShape& shape = kernel->outputs[i].shape;
       VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
 
-      gpu::DeviceMemoryBase buffer;
-      if (output_is_tuple) {
-        buffer = output->buffer({output_num});
-      } else {
-        CHECK_EQ(0, output_num);
-        buffer = output->buffer({});
-      }
+      gpu::DeviceMemoryBase buffer = output->buffer({output_num});
       Tensor output_tensor;
       // Looks up the owning Tensor by buffer address.
       OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer(
@@ -387,13 +379,7 @@
     TensorShape write_shape;
     OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(write.shape, &write_shape));
 
-    gpu::DeviceMemoryBase buffer;
-    if (output_is_tuple) {
-      buffer = output->buffer({output_num});
-    } else {
-      CHECK_EQ(0, output_num);
-      buffer = output->buffer({});
-    }
+    gpu::DeviceMemoryBase buffer = output->buffer({output_num});
 
     Var* variable = nullptr;
     // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, not
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 2c17f46..55995aa 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -96,8 +96,6 @@
   OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs(
                           ctx, &arguments, &has_uninitialized_vars));
 
-  const bool use_tuple_arg = (arguments.size() != 1);
-
   xla::ComputationBuilder* builder = ctx->builder();
   XlaCompiler* compiler = ctx->compiler();
 
@@ -112,7 +110,7 @@
   // TODO(phawkins): consider adding loop-invariant inputs to XLA's While()
   // operator.
   XlaCompiler::CompileOptions body_options;
-  body_options.use_tuple_arg = use_tuple_arg;
+  body_options.use_tuple_arg = true;
   body_options.return_updated_values_for_all_resources = true;
   body_options.resolve_compile_time_constants = false;
   XlaCompiler::CompilationResult body;
@@ -160,22 +158,16 @@
   VLOG(1) << "Compiling condition";
 
   XlaCompiler::CompileOptions cond_options;
-  cond_options.use_tuple_arg = use_tuple_arg;
+  cond_options.use_tuple_arg = true;
   cond_options.resolve_compile_time_constants = false;
   XlaCompiler::CompilationResult cond;
   OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
                                                 arguments, &cond));
 
-  xla::Shape body_input_shape, cond_input_shape;
-  if (use_tuple_arg) {
-    body_input_shape = xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes);
-    cond_input_shape = xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes);
-  } else {
-    CHECK(!body.xla_input_shapes.empty());
-    body_input_shape = body.xla_input_shapes[0];
-    CHECK(!cond.xla_input_shapes.empty());
-    cond_input_shape = cond.xla_input_shapes[0];
-  }
+  xla::Shape body_input_shape =
+      xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes);
+  xla::Shape cond_input_shape =
+      xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes);
 
   VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape)
           << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape);
@@ -195,10 +187,16 @@
           xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
           xla::ShapeUtil::HumanString(body.xla_output_shape)));
 
-  xla::ComputationDataHandle data;
+  xla::Shape expected_cond_output_shape = xla::ShapeUtil::MakeTupleShape(
+      {xla::ShapeUtil::MakeShape(xla::PRED, {})});
+  OP_REQUIRES(ctx,
+              xla::ShapeUtil::Compatible(cond.xla_output_shape,
+                                         expected_cond_output_shape),
+              errors::InvalidArgument(
+                  "Output shape of loop condition should be (pred[]), got: ",
+                  xla::ShapeUtil::HumanString(cond.xla_output_shape)));
 
   int num_inputs = body.input_mapping.size();
-
   std::vector<xla::ComputationDataHandle> inputs(num_inputs);
   for (int i = 0; i < num_inputs; ++i) {
     int input_num = body.input_mapping[i];
@@ -211,30 +209,31 @@
     }
   }
 
-  xla::ComputationDataHandle init;
-  if (use_tuple_arg) {
-    init = builder->Tuple(inputs);
-  } else {
-    init = inputs[0];
-  }
+  xla::ComputationDataHandle init = builder->Tuple(inputs);
 
   VLOG(1) << "Building while loop";
 
-  xla::ComputationDataHandle while_result =
-      builder->While(*cond.computation, *body.computation, init);
+  // Wraps the condition in a computation that unpacks the output tuple.
+  xla::Computation cond_wrapper;
+  {
+    std::unique_ptr<xla::ComputationBuilder> cb =
+        builder->CreateSubBuilder("cond_wrapper");
+    auto inputs = cb->Parameter(0, cond_input_shape, "inputs");
+    auto outputs = cb->Call(*cond.computation, {inputs});
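+    // The last instruction added becomes the root: the scalar PRED.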
+    cb->GetTupleElement(outputs, 0);
+    xla::StatusOr<xla::Computation> result = cb->Build();
+    OP_REQUIRES_OK(ctx, result.status());
+    cond_wrapper = std::move(result.ValueOrDie());
+  }
 
-  auto get_loop_output = [&](int i) {
-    if (use_tuple_arg) {
-      return builder->GetTupleElement(while_result, i);
-    } else {
-      return while_result;
-    }
-  };
+  xla::ComputationDataHandle while_result =
+      builder->While(cond_wrapper, *body.computation, init);
 
   // Sets non-variable outputs.
   for (int i = 0; i < ctx->num_outputs(); ++i) {
     if (ctx->input_type(i) != DT_RESOURCE) {
-      ctx->SetOutput(body.input_mapping[i], get_loop_output(i));
+      ctx->SetOutput(body.input_mapping[i],
+                     builder->GetTupleElement(while_result, i));
     }
   }
 
@@ -245,7 +244,7 @@
     OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
     if (update.modified) {
       int pos = body.outputs.size() + i;
-      resource->value = get_loop_output(pos);
+      resource->value = builder->GetTupleElement(while_result, pos);
     }
     VLOG(2) << "Loop-carried variable: pos: " << update.input_index
             << " name: " << resource->name << " modified: " << update.modified
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
index 57b53cc..51ce17d 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -92,7 +92,7 @@
       client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
   TF_EXPECT_OK(result_or.status());
   std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
-  EXPECT_EQ("42", result->ToString());
+  EXPECT_EQ("(s32[]) (\n42,\n)", result->ToString());
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 66c91ae..08b9faa 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -385,14 +385,7 @@
   if (!elems.empty() || has_side_effects) {
    // Builds the output tuple; this is an empty tuple for computations that
    // have side effects but no return values.
-    xla::ComputationDataHandle handle = builder->Tuple(elems);
-
-    // TODO(b/31775371): to workaround bug, we must build a no-op computation
-    // that is guaranteed to be constructed after all of the formal parameters
-    // to the computation. Once the bug is fixed, we could avoid tupling here.
-    if (elems.size() == 1) {
-      handle = builder->GetTupleElement(handle, 0);
-    }
+    builder->Tuple(elems);
 
     // Builds the XLA computation.
     xla::StatusOr<xla::Computation> computation_status = builder->Build();
@@ -512,28 +505,18 @@
       CHECK_LT(computation_output, num_computation_outputs);
       OutputDescription& output = result->outputs[i];
       output.is_constant = false;
-      if (num_computation_outputs > 1) {
-        TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
-            xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape,
-                                                 computation_output),
-            &output.shape));
-      } else {
-        TF_RETURN_IF_ERROR(
-            XLAShapeToTensorShape(result->xla_output_shape, &output.shape));
-      }
+      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
+          xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape,
+                                               computation_output),
+          &output.shape));
       ++computation_output;
     }
   }
 
   for (std::vector<ResourceUpdate>::size_type i = 0;
        i < result->resource_updates.size(); ++i) {
-    if (num_computation_outputs > 1) {
-      result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape(
-          result->xla_output_shape, computation_output);
-    } else {
-      CHECK_EQ(0, computation_output);
-      result->resource_updates[i].shape = result->xla_output_shape;
-    }
+    result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape(
+        result->xla_output_shape, computation_output);
     ++computation_output;
   }
   return Status::OK();
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 6727eb6..809f668 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -165,8 +165,7 @@
     // Should the arguments be packed into a single tuple?
     bool tuple_arg;
 
-    // Output shape in XLA format. The output shape is a tuple if and only if
-    // the number of non-constant outputs is not equal to 1.
+    // Output shape in XLA format. The output shape is always a tuple.
     xla::Shape xla_output_shape;
 
     // TensorFlow shapes of outputs, together with the values of any
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index a1e4dcb..aa8df80 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -208,8 +208,10 @@
   std::unique_ptr<xla::Literal> actual_literal =
       client_->Transfer(*actual).ConsumeValueOrDie();
 
-  std::unique_ptr<xla::Literal> expected_literal =
+  std::unique_ptr<xla::Literal> expected0 =
       xla::Literal::CreateR1<int32>({4, 143});
+  std::unique_ptr<xla::Literal> expected_literal =
+      xla::Literal::MakeTuple({expected0.get()});
   xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
 }
 
@@ -265,8 +267,10 @@
     std::unique_ptr<xla::Literal> actual_literal =
         client_->Transfer(*actual).ConsumeValueOrDie();
 
-    std::unique_ptr<xla::Literal> expected_literal =
+    std::unique_ptr<xla::Literal> expected0 =
         xla::Literal::CreateR1<int32>({-7, -42});
+    std::unique_ptr<xla::Literal> expected_literal =
+        xla::Literal::MakeTuple({expected0.get()});
     xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
   }
 
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 1a18b28..cd25d67 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -53,6 +53,7 @@
         "//tensorflow/compiler/xla/service:batchnorm_rewriter",
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:buffer_liveness",
+        "//tensorflow/compiler/xla/service:call_inliner",
         "//tensorflow/compiler/xla/service:copy_insertion",
         "//tensorflow/compiler/xla/service:executable",
         "//tensorflow/compiler/xla/service:flatten_call_graph",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 8d77a63..afd9f72 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -45,6 +45,7 @@
 #include "tensorflow/compiler/xla/service/batchnorm_rewriter.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
+#include "tensorflow/compiler/xla/service/call_inliner.h"
 #include "tensorflow/compiler/xla/service/copy_insertion.h"
 #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
 #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
@@ -261,6 +262,10 @@
   // where we will take this pass in future.
   // pipeline.AddPass<Inliner>();
 
+  // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
+  // pass.
+  pipeline.AddPass<CallInliner>();
+
   pipeline.AddPass<ConvCanonicalization>();
   {
     auto& pass =
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
index 89f9b8a..569d594 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
@@ -44,8 +44,8 @@
   OpaqueData opaque_data{100, 20, 3};
   void* parameters[] = {&opaque_data};
   float out = 0;
-  char tmp[20] = {0};
-  void* temporary_buffers[] = {&out, nullptr, &tmp};
+  char tmp[4] = {0};
+  void* temporary_buffers[] = {nullptr, &out, &tmp};
   SumAndDouble(&out, &run_options, parameters, temporary_buffers);
   EXPECT_EQ(out, 246.0f);
 
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index 9266760..0cd44a7 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -88,11 +88,11 @@
       std::move(results.front()));
   // It's lame to hard-code the buffer assignments, but we need
   // local_client_aot_test.cc to be able to easily invoke the function.
-  CHECK_EQ(result->result_buffer_index(), 0);
+  CHECK_EQ(result->result_buffer_index(), 1);
   CHECK_EQ(result->buffer_sizes().size(), 3);
-  CHECK_EQ(result->buffer_sizes()[0], sizeof(float));  // result buffer
-  CHECK_EQ(result->buffer_sizes()[1], -1);             // param buffer
-  CHECK_EQ(result->buffer_sizes()[2], 20);             // temp buffer
+  CHECK_EQ(result->buffer_sizes()[0], -1);             // param buffer
+  CHECK_EQ(result->buffer_sizes()[1], sizeof(float));  // result buffer
+  CHECK_EQ(result->buffer_sizes()[2], sizeof(float));  // temp buffer
   if (triple.isOSBinFormatELF()) {
     // Check the ELF magic.
     CHECK_EQ(result->object_file_data()[0], 0x7F);
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index a0f9be3..bb2d90f 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -967,6 +967,53 @@
   ComputeAndCompareR0<int32>(&builder, 42, {});
 }
 
+// Tests a while node whose condition calls a subcomputation that returns a
+// tuple.
+// f = lambda result: tuple({result < 5})
+// int32 result = 0;
+// while (f(result).get<0>()) {
+//   result = result + 1;
+// }
+TEST_F(WhileTest, WhileWithCallInsideCondition) {
+  auto result_shape = ShapeUtil::MakeShape(S32, {});
+
+  // Create a computation for the condition: repeat for 5 iterations.
+  Computation condition_callee;
+  {
+    ComputationBuilder builder(client_, "condition_callee");
+    auto prev = builder.Parameter(0, result_shape, "prev");
+    builder.Tuple({builder.Gt(builder.ConstantR0<int32>(5), prev)});
+
+    condition_callee = builder.Build().ConsumeValueOrDie();
+  }
+
+  Computation condition;
+  {
+    ComputationBuilder builder(client_, "condition");
+    auto prev = builder.Parameter(0, result_shape, "prev");
+    auto result = builder.Call(condition_callee, {prev});
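+    // Unpacking the PRED makes it the root of the condition computation.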
+    builder.GetTupleElement(result, 0);
+    condition = builder.Build().ConsumeValueOrDie();
+  }
+
+  // Create a computation for the body: add 1 to the result variable.
+  Computation body;
+  {
+    ComputationBuilder builder(client_, "body");
+    auto prev = builder.Parameter(0, result_shape, "prev");
+    auto input = builder.ConstantR0<int32>(1);
+    auto result = builder.Add(input, prev);
+    body = builder.Build().ConsumeValueOrDie();
+  }
+
+  // Create a While node with computations for the condition and the body.
+  ComputationBuilder builder(client_, TestName());
+  auto init = builder.ConstantR0<int32>(0);
+  auto result = builder.While(condition, body, init);
+  auto shape = builder.GetShape(result).ConsumeValueOrDie();
+
+  ComputeAndCompareR0<int32>(&builder, 5, {});
+}
+
 void BM_WhileLoop(int num_iters) {
   // Benchmark a simple kernel to measure while loop overheads.
   tensorflow::testing::StopTiming();
diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
index db811bd..1442693 100644
--- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
+++ b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
@@ -112,7 +112,7 @@
       std::unique_ptr<xla::SessionModule> session_module,
       ConvertTfGraphToXlaSessionModule(args, std::move(graph)));
 
-  ASSERT_EQ(5, session_module->entry().requests_size());
+  ASSERT_EQ(4, session_module->entry().requests_size());
 
   VLOG(1) << "--- DUMP ---";
   VLOG(1) << session_module->DebugString();