[TF:XLA] Always return a tuple-shaped result when converting a TensorFlow graph into an XLA computation. Previously we had special-case logic that avoided the tuple when the number of outputs was 1, but this meant that any code that wanted to reason about the outputs also needed a special case. Using tuples unconditionally lets us simplify the code.
Since a loop condition must return an untupled PRED scalar, wrap each loop condition computation in a wrapper computation that unpacks the output tuple and returns the PRED value.
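As a rough illustration, the wrapper pattern looks like the following minimal sketch, mirroring the while_op.cc change below (variable names here are illustrative only):

  // Build a wrapper computation that calls the tuple-returning condition
  // and makes its root the scalar PRED at tuple index 0.
  std::unique_ptr<xla::ComputationBuilder> cb =
      builder->CreateSubBuilder("cond_wrapper");
  auto args = cb->Parameter(0, cond_input_shape, "inputs");
  auto outputs = cb->Call(*cond.computation, {args});
  cb->GetTupleElement(outputs, 0);  // last op built becomes the root: pred[]
  xla::Computation cond_wrapper = cb->Build().ConsumeValueOrDie();
  // The wrapper, not the raw condition, is what gets passed to While():
  //   builder->While(cond_wrapper, *body.computation, init);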
[XLA:CPU] Use the CallInliner pass on CPU to work around a wrong-output bug exposed by this change. Add a unit test that exhibits the bug when the CallInliner is disabled.
PiperOrigin-RevId: 169333679
diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
index 7e760d7..5cbff88 100644
--- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
@@ -259,7 +259,6 @@
XlaLocalRuntimeContext local_runtime_context;
std::unique_ptr<xla::ShapedBuffer> output;
- bool output_is_tuple;
if (!kernel->computation->IsNull()) {
// Build xla::ShapedBuffers that point directly to the Tensor buffers.
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
@@ -326,7 +325,6 @@
if (VLOG_IS_ON(2)) {
VLOG(2) << "Result tuple shape: " << output->shape().DebugString();
}
- output_is_tuple = xla::ShapeUtil::IsTuple(output->shape());
}
CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
@@ -356,13 +354,7 @@
const TensorShape& shape = kernel->outputs[i].shape;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
- gpu::DeviceMemoryBase buffer;
- if (output_is_tuple) {
- buffer = output->buffer({output_num});
- } else {
- CHECK_EQ(0, output_num);
- buffer = output->buffer({});
- }
+ gpu::DeviceMemoryBase buffer = output->buffer({output_num});
Tensor output_tensor;
// Looks up the owning Tensor by buffer address.
OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer(
@@ -387,13 +379,7 @@
TensorShape write_shape;
OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(write.shape, &write_shape));
- gpu::DeviceMemoryBase buffer;
- if (output_is_tuple) {
- buffer = output->buffer({output_num});
- } else {
- CHECK_EQ(0, output_num);
- buffer = output->buffer({});
- }
+ gpu::DeviceMemoryBase buffer = output->buffer({output_num});
Var* variable = nullptr;
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, not
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 2c17f46..55995aa 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -96,8 +96,6 @@
OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs(
ctx, &arguments, &has_uninitialized_vars));
- const bool use_tuple_arg = (arguments.size() != 1);
-
xla::ComputationBuilder* builder = ctx->builder();
XlaCompiler* compiler = ctx->compiler();
@@ -112,7 +110,7 @@
// TODO(phawkins): consider adding loop-invariant inputs to XLA's While()
// operator.
XlaCompiler::CompileOptions body_options;
- body_options.use_tuple_arg = use_tuple_arg;
+ body_options.use_tuple_arg = true;
body_options.return_updated_values_for_all_resources = true;
body_options.resolve_compile_time_constants = false;
XlaCompiler::CompilationResult body;
@@ -160,22 +158,16 @@
VLOG(1) << "Compiling condition";
XlaCompiler::CompileOptions cond_options;
- cond_options.use_tuple_arg = use_tuple_arg;
+ cond_options.use_tuple_arg = true;
cond_options.resolve_compile_time_constants = false;
XlaCompiler::CompilationResult cond;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_,
arguments, &cond));
- xla::Shape body_input_shape, cond_input_shape;
- if (use_tuple_arg) {
- body_input_shape = xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes);
- cond_input_shape = xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes);
- } else {
- CHECK(!body.xla_input_shapes.empty());
- body_input_shape = body.xla_input_shapes[0];
- CHECK(!cond.xla_input_shapes.empty());
- cond_input_shape = cond.xla_input_shapes[0];
- }
+ xla::Shape body_input_shape =
+ xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes);
+ xla::Shape cond_input_shape =
+ xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes);
VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape)
<< " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape);
@@ -195,10 +187,16 @@
xla::ShapeUtil::HumanString(body_input_shape), " vs. ",
xla::ShapeUtil::HumanString(body.xla_output_shape)));
- xla::ComputationDataHandle data;
+ xla::Shape expected_cond_output_shape = xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::PRED, {})});
+ OP_REQUIRES(ctx,
+ xla::ShapeUtil::Compatible(cond.xla_output_shape,
+ expected_cond_output_shape),
+ errors::InvalidArgument(
+ "Output shape of loop condition should be (pred[]), got: ",
+ xla::ShapeUtil::HumanString(cond.xla_output_shape)));
int num_inputs = body.input_mapping.size();
-
std::vector<xla::ComputationDataHandle> inputs(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
int input_num = body.input_mapping[i];
@@ -211,30 +209,31 @@
}
}
- xla::ComputationDataHandle init;
- if (use_tuple_arg) {
- init = builder->Tuple(inputs);
- } else {
- init = inputs[0];
- }
+ xla::ComputationDataHandle init = builder->Tuple(inputs);
VLOG(1) << "Building while loop";
- xla::ComputationDataHandle while_result =
- builder->While(*cond.computation, *body.computation, init);
+ // Wraps the condition in a computation that unpacks the output tuple.
+ xla::Computation cond_wrapper;
+ {
+ std::unique_ptr<xla::ComputationBuilder> cb =
+ builder->CreateSubBuilder("cond_wrapper");
+ auto inputs = cb->Parameter(0, cond_input_shape, "inputs");
+ auto outputs = cb->Call(*cond.computation, {inputs});
+ cb->GetTupleElement(outputs, 0);
+ xla::StatusOr<xla::Computation> result = cb->Build();
+ OP_REQUIRES_OK(ctx, result.status());
+ cond_wrapper = std::move(result.ValueOrDie());
+ }
- auto get_loop_output = [&](int i) {
- if (use_tuple_arg) {
- return builder->GetTupleElement(while_result, i);
- } else {
- return while_result;
- }
- };
+ xla::ComputationDataHandle while_result =
+ builder->While(cond_wrapper, *body.computation, init);
// Sets non-variable outputs.
for (int i = 0; i < ctx->num_outputs(); ++i) {
if (ctx->input_type(i) != DT_RESOURCE) {
- ctx->SetOutput(body.input_mapping[i], get_loop_output(i));
+ ctx->SetOutput(body.input_mapping[i],
+ builder->GetTupleElement(while_result, i));
}
}
@@ -245,7 +244,7 @@
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
if (update.modified) {
int pos = body.outputs.size() + i;
- resource->value = get_loop_output(pos);
+ resource->value = builder->GetTupleElement(while_result, pos);
}
VLOG(2) << "Loop-carried variable: pos: " << update.input_index
<< " name: " << resource->name << " modified: " << update.modified
diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc
index 57b53cc..51ce17d 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_test.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc
@@ -92,7 +92,7 @@
client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()});
TF_EXPECT_OK(result_or.status());
std::unique_ptr<xla::Literal> result = std::move(result_or.ValueOrDie());
- EXPECT_EQ("42", result->ToString());
+ EXPECT_EQ("(s32[]) (\n42,\n)", result->ToString());
}
} // namespace
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 66c91ae..08b9faa 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -385,14 +385,7 @@
if (!elems.empty() || has_side_effects) {
// Builds a empty tuple return value for computations that have side effects
// but have no return values.
- xla::ComputationDataHandle handle = builder->Tuple(elems);
-
- // TODO(b/31775371): to workaround bug, we must build a no-op computation
- // that is guaranteed to be constructed after all of the formal parameters
- // to the computation. Once the bug is fixed, we could avoid tupling here.
- if (elems.size() == 1) {
- handle = builder->GetTupleElement(handle, 0);
- }
+ builder->Tuple(elems);
// Builds the XLA computation.
xla::StatusOr<xla::Computation> computation_status = builder->Build();
@@ -512,28 +505,18 @@
CHECK_LT(computation_output, num_computation_outputs);
OutputDescription& output = result->outputs[i];
output.is_constant = false;
- if (num_computation_outputs > 1) {
- TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
- xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape,
- computation_output),
- &output.shape));
- } else {
- TF_RETURN_IF_ERROR(
- XLAShapeToTensorShape(result->xla_output_shape, &output.shape));
- }
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
+ xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape,
+ computation_output),
+ &output.shape));
++computation_output;
}
}
for (std::vector<ResourceUpdate>::size_type i = 0;
i < result->resource_updates.size(); ++i) {
- if (num_computation_outputs > 1) {
- result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape(
- result->xla_output_shape, computation_output);
- } else {
- CHECK_EQ(0, computation_output);
- result->resource_updates[i].shape = result->xla_output_shape;
- }
+ result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape(
+ result->xla_output_shape, computation_output);
++computation_output;
}
return Status::OK();
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 6727eb6..809f668 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -165,8 +165,7 @@
// Should the arguments be packed into a single tuple?
bool tuple_arg;
- // Output shape in XLA format. The output shape is a tuple if and only if
- // the number of non-constant outputs is not equal to 1.
+ // Output shape in XLA format. The output shape is always a tuple.
xla::Shape xla_output_shape;
// TensorFlow shapes of outputs, together with the values of any
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index a1e4dcb..aa8df80 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -208,8 +208,10 @@
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected_literal =
+ std::unique_ptr<xla::Literal> expected0 =
xla::Literal::CreateR1<int32>({4, 143});
+ std::unique_ptr<xla::Literal> expected_literal =
+ xla::Literal::MakeTuple({expected0.get()});
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
}
@@ -265,8 +267,10 @@
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> expected_literal =
+ std::unique_ptr<xla::Literal> expected0 =
xla::Literal::CreateR1<int32>({-7, -42});
+ std::unique_ptr<xla::Literal> expected_literal =
+ xla::Literal::MakeTuple({expected0.get()});
xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
}
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 1a18b28..cd25d67 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -53,6 +53,7 @@
"//tensorflow/compiler/xla/service:batchnorm_rewriter",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:buffer_liveness",
+ "//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 8d77a63..afd9f72 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -45,6 +45,7 @@
#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
+#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
@@ -261,6 +262,10 @@
// where we will take this pass in future.
// pipeline.AddPass<Inliner>();
+ // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
+ // pass.
+ pipeline.AddPass<CallInliner>();
+
pipeline.AddPass<ConvCanonicalization>();
{
auto& pass =
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
index 89f9b8a..569d594 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
@@ -44,8 +44,8 @@
OpaqueData opaque_data{100, 20, 3};
void* parameters[] = {&opaque_data};
float out = 0;
- char tmp[20] = {0};
- void* temporary_buffers[] = {&out, nullptr, &tmp};
+ char tmp[4] = {0};
+ void* temporary_buffers[] = {nullptr, &out, &tmp};
SumAndDouble(&out, &run_options, parameters, temporary_buffers);
EXPECT_EQ(out, 246.0f);
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index 9266760..0cd44a7 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -88,11 +88,11 @@
std::move(results.front()));
// It's lame to hard-code the buffer assignments, but we need
// local_client_aot_test.cc to be able to easily invoke the function.
- CHECK_EQ(result->result_buffer_index(), 0);
+ CHECK_EQ(result->result_buffer_index(), 1);
CHECK_EQ(result->buffer_sizes().size(), 3);
- CHECK_EQ(result->buffer_sizes()[0], sizeof(float)); // result buffer
- CHECK_EQ(result->buffer_sizes()[1], -1); // param buffer
- CHECK_EQ(result->buffer_sizes()[2], 20); // temp buffer
+ CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer
+ CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer
+ CHECK_EQ(result->buffer_sizes()[2], sizeof(float)); // temp buffer
if (triple.isOSBinFormatELF()) {
// Check the ELF magic.
CHECK_EQ(result->object_file_data()[0], 0x7F);
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index a0f9be3..bb2d90f 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -967,6 +967,53 @@
ComputeAndCompareR0<int32>(&builder, 42, {});
}
+// Tests a while loop whose condition calls a computation that returns a tuple.
+// f = lambda result: tuple({result < 5})
+// int32 result = 0;
+// while (f(result).get<0>()) {
+// result = result + 1;
+// }
+TEST_F(WhileTest, WhileWithCallInsideCondition) {
+ auto result_shape = ShapeUtil::MakeShape(S32, {});
+
+ // Create a computation for the condition: repeat for 5 iterations.
+ Computation condition_callee;
+ {
+ ComputationBuilder builder(client_, "condition_callee");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ builder.Tuple({builder.Gt(builder.ConstantR0<int32>(5), prev)});
+
+ condition_callee = builder.Build().ConsumeValueOrDie();
+ }
+
+ Computation condition;
+ {
+ ComputationBuilder builder(client_, "condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto result = builder.Call(condition_callee, {prev});
+ builder.GetTupleElement(result, 0);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the body: add 1 to the result variable.
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto input = builder.ConstantR0<int32>(1);
+ auto result = builder.Add(input, prev);
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ ComputationBuilder builder(client_, TestName());
+ auto init = builder.ConstantR0<int32>(0);
+ auto result = builder.While(condition, body, init);
+ auto shape = builder.GetShape(result).ConsumeValueOrDie();
+
+ ComputeAndCompareR0<int32>(&builder, 5, {});
+}
+
void BM_WhileLoop(int num_iters) {
// Benchmark a simple kernel to measure while loop overheads.
tensorflow::testing::StopTiming();
diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
index db811bd..1442693 100644
--- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
+++ b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc
@@ -112,7 +112,7 @@
std::unique_ptr<xla::SessionModule> session_module,
ConvertTfGraphToXlaSessionModule(args, std::move(graph)));
- ASSERT_EQ(5, session_module->entry().requests_size());
+ ASSERT_EQ(4, session_module->entry().requests_size());
VLOG(1) << "--- DUMP ---";
VLOG(1) << session_module->DebugString();