blob: 811705d2b1711cd4ce0dd3f4ab706775da709573 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#include <utility>
17
Frederic Bastien4a544382020-08-21 10:20:58 -070018#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
19#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
Sanjoy Dasdcf568a2018-07-23 17:49:04 -070020#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
21#include "tensorflow/compiler/xla/service/hlo_module_config.h"
22#include "tensorflow/compiler/xla/service/hlo_parser.h"
23#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24#include "tensorflow/core/platform/test.h"
25
26namespace xla {
27namespace gpu {
28namespace {
29
// Test fixture for GPU fusion codegen tests. Inherits the
// compile-and-verify helpers (e.g. CompileAndVerifyIr) from GpuCodegenTest.
class GpuFusionTest : public GpuCodegenTest {};
31
32TEST_F(GpuFusionTest, FusedReshape) {
33 const char* hlo_text = R"(
34 HloModule test_module
35
36 fused_computation {
37 p0.param_0 = f32[4,1,1]{2,1,0} parameter(0)
38 p1.param_1 = f32[4,1]{1,0} parameter(1)
39 reshape = f32[4,1]{1,0} reshape(p0.param_0)
40 ROOT add = f32[4,1] add(reshape, p1.param_1)
41 }
42
43 ENTRY BroadcastIntoAdd {
44 p0 = f32[4,1,1]{2,1,0} parameter(0)
45 p1 = f32[4,1]{1,0} parameter(1)
46 ROOT fusion = f32[4,1]{1,0} fusion(p0, p1), kind=kLoop,
47 calls=fused_computation
48 }
49)";
50
51 CompileAndVerifyIr(hlo_text,
52 R"(
53; CHECK-LABEL: @fusion
54; CHECK: fadd
55; CHECK: }
56 )");
57}
58
Frederic Bastien4a544382020-08-21 10:20:58 -070059// Check that we limit the number of operands to fusions we create.
60TEST_F(GpuFusionTest, FusedBiggerThenThresholdButDoNotChangeTheFusionl) {
61 constexpr int64 kNumParams = kMaxOperandsAndOutputsPerFusion + 1;
62
63 // Compute
64 // p0 + p1 + p2 + ... + pn,
65 // Use so many parameters that they do not fit into one fusion.
66 auto module = CreateNewVerifiedModule();
67 HloComputation::Builder b(TestName());
68 Shape input_shape = ShapeUtil::MakeShape(F32, {10, 100});
69 Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 2});
TensorFlower Gardenerf2e22672020-08-31 02:58:08 -070070 Shape concat_shape = ShapeUtil::MakeShape(F32, {10, 2 * kNumParams});
71 HloInstruction* input =
72 b.AddInstruction(HloInstruction::CreateParameter(0, input_shape, "p"));
Frederic Bastien4a544382020-08-21 10:20:58 -070073
74 std::vector<HloInstruction*> slice_params;
75 for (int64 i = 0; i < kNumParams; ++i) {
TensorFlower Gardenerf2e22672020-08-31 02:58:08 -070076 slice_params.push_back(b.AddInstruction(HloInstruction::CreateSlice(
77 slice_shape, input, {0, 0}, {10, 2}, {1, 1})));
Frederic Bastien4a544382020-08-21 10:20:58 -070078 }
TensorFlower Gardenerf2e22672020-08-31 02:58:08 -070079 b.AddInstruction(
80 HloInstruction::CreateConcatenate(concat_shape, slice_params, 1));
Frederic Bastien4a544382020-08-21 10:20:58 -070081 module->AddEntryComputation(b.Build());
82 EXPECT_TRUE(GpuInstructionFusion(false).Run(module.get()).ValueOrDie());
TensorFlower Gardenerf2e22672020-08-31 02:58:08 -070083 EXPECT_TRUE(module->entry_computation()->root_instruction()->opcode() ==
84 HloOpcode::kFusion);
85 for (HloInstruction* instr : module->entry_computation()->instructions()) {
Frederic Bastien4a544382020-08-21 10:20:58 -070086 EXPECT_TRUE(instr->opcode() != HloOpcode::kSlice);
87 }
88}
89
Sanjoy Dasdcf568a2018-07-23 17:49:04 -070090} // namespace
91} // namespace gpu
92} // namespace xla