blob: 9416e592012de5fa244a6db1cb8be3d1e995099c [file] [log] [blame]
Nicolai Haehnle59041682018-10-18 09:38:44 +00001//===- DivergenceAnalysisTest.cpp - DivergenceAnalysis unit tests ---------===//
2//
Chandler Carruth2946cd72019-01-19 08:50:56 +00003// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Nicolai Haehnle59041682018-10-18 09:38:44 +00006//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/ADT/SmallVector.h"
10#include "llvm/Analysis/AssumptionCache.h"
11#include "llvm/Analysis/DivergenceAnalysis.h"
12#include "llvm/Analysis/LoopInfo.h"
13#include "llvm/Analysis/PostDominators.h"
14#include "llvm/Analysis/SyncDependenceAnalysis.h"
15#include "llvm/Analysis/TargetLibraryInfo.h"
16#include "llvm/AsmParser/Parser.h"
17#include "llvm/IR/Constants.h"
18#include "llvm/IR/Dominators.h"
19#include "llvm/IR/GlobalVariable.h"
20#include "llvm/IR/IRBuilder.h"
21#include "llvm/IR/InstIterator.h"
22#include "llvm/IR/LLVMContext.h"
23#include "llvm/IR/LegacyPassManager.h"
24#include "llvm/IR/Module.h"
25#include "llvm/IR/Verifier.h"
26#include "llvm/Support/SourceMgr.h"
27#include "gtest/gtest.h"
28
29namespace llvm {
30namespace {
31
32BasicBlock *GetBlockByName(StringRef BlockName, Function &F) {
33 for (auto &BB : F) {
34 if (BB.getName() != BlockName)
35 continue;
36 return &BB;
37 }
38 return nullptr;
39}
40
41// We use this fixture to ensure that we clean up DivergenceAnalysis before
42// deleting the PassManager.
43class DivergenceAnalysisTest : public testing::Test {
44protected:
45 LLVMContext Context;
46 Module M;
47 TargetLibraryInfoImpl TLII;
48 TargetLibraryInfo TLI;
49
50 std::unique_ptr<DominatorTree> DT;
51 std::unique_ptr<PostDominatorTree> PDT;
52 std::unique_ptr<LoopInfo> LI;
53 std::unique_ptr<SyncDependenceAnalysis> SDA;
54
55 DivergenceAnalysisTest() : M("", Context), TLII(), TLI(TLII) {}
56
57 DivergenceAnalysis buildDA(Function &F, bool IsLCSSA) {
58 DT.reset(new DominatorTree(F));
59 PDT.reset(new PostDominatorTree(F));
60 LI.reset(new LoopInfo(*DT));
61 SDA.reset(new SyncDependenceAnalysis(*DT, *PDT, *LI));
62 return DivergenceAnalysis(F, nullptr, *DT, *LI, *SDA, IsLCSSA);
63 }
64
65 void runWithDA(
66 Module &M, StringRef FuncName, bool IsLCSSA,
67 function_ref<void(Function &F, LoopInfo &LI, DivergenceAnalysis &DA)>
68 Test) {
69 auto *F = M.getFunction(FuncName);
70 ASSERT_NE(F, nullptr) << "Could not find " << FuncName;
71 DivergenceAnalysis DA = buildDA(*F, IsLCSSA);
72 Test(*F, *LI, DA);
73 }
74};
75
76// Simple initial state test
77TEST_F(DivergenceAnalysisTest, DAInitialState) {
78 IntegerType *IntTy = IntegerType::getInt32Ty(Context);
79 FunctionType *FTy =
80 FunctionType::get(Type::getVoidTy(Context), {IntTy}, false);
James Y Knight13680222019-02-01 02:28:03 +000081 Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
Nicolai Haehnle59041682018-10-18 09:38:44 +000082 BasicBlock *BB = BasicBlock::Create(Context, "entry", F);
83 ReturnInst::Create(Context, nullptr, BB);
84
85 DivergenceAnalysis DA = buildDA(*F, false);
86
87 // Whole function region
88 EXPECT_EQ(DA.getRegionLoop(), nullptr);
89
90 // No divergence in initial state
91 EXPECT_FALSE(DA.hasDetectedDivergence());
92
93 // No spurious divergence
94 DA.compute();
95 EXPECT_FALSE(DA.hasDetectedDivergence());
96
97 // Detected divergence after marking
98 Argument &arg = *F->arg_begin();
99 DA.markDivergent(arg);
100
101 EXPECT_TRUE(DA.hasDetectedDivergence());
102 EXPECT_TRUE(DA.isDivergent(arg));
103
104 DA.compute();
105 EXPECT_TRUE(DA.hasDetectedDivergence());
106 EXPECT_TRUE(DA.isDivergent(arg));
107}
108
109TEST_F(DivergenceAnalysisTest, DANoLCSSA) {
110 LLVMContext C;
111 SMDiagnostic Err;
112
113 std::unique_ptr<Module> M = parseAssemblyString(
114 "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" "
115 " "
116 "define i32 @f_1(i8* nocapture %arr, i32 %n, i32* %A, i32* %B) "
117 " local_unnamed_addr { "
118 "entry: "
119 " br label %loop.ph "
120 " "
121 "loop.ph: "
122 " br label %loop "
123 " "
124 "loop: "
125 " %iv0 = phi i32 [ %iv0.inc, %loop ], [ 0, %loop.ph ] "
126 " %iv1 = phi i32 [ %iv1.inc, %loop ], [ -2147483648, %loop.ph ] "
127 " %iv0.inc = add i32 %iv0, 1 "
128 " %iv1.inc = add i32 %iv1, 3 "
129 " %cond.cont = icmp slt i32 %iv0, %n "
130 " br i1 %cond.cont, label %loop, label %for.end.loopexit "
131 " "
132 "for.end.loopexit: "
133 " ret i32 %iv0 "
134 "} ",
135 Err, C);
136
137 Function *F = M->getFunction("f_1");
138 DivergenceAnalysis DA = buildDA(*F, false);
139 EXPECT_FALSE(DA.hasDetectedDivergence());
140
141 auto ItArg = F->arg_begin();
142 ItArg++;
143 auto &NArg = *ItArg;
144
145 // Seed divergence in argument %n
146 DA.markDivergent(NArg);
147
148 DA.compute();
149 EXPECT_TRUE(DA.hasDetectedDivergence());
150
151 // Verify that "ret %iv.0" is divergent
152 auto ItBlock = F->begin();
153 std::advance(ItBlock, 3);
154 auto &ExitBlock = *GetBlockByName("for.end.loopexit", *F);
155 auto &RetInst = *cast<ReturnInst>(ExitBlock.begin());
156 EXPECT_TRUE(DA.isDivergent(RetInst));
157}
158
159TEST_F(DivergenceAnalysisTest, DALCSSA) {
160 LLVMContext C;
161 SMDiagnostic Err;
162
163 std::unique_ptr<Module> M = parseAssemblyString(
164 "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" "
165 " "
166 "define i32 @f_lcssa(i8* nocapture %arr, i32 %n, i32* %A, i32* %B) "
167 " local_unnamed_addr { "
168 "entry: "
169 " br label %loop.ph "
170 " "
171 "loop.ph: "
172 " br label %loop "
173 " "
174 "loop: "
175 " %iv0 = phi i32 [ %iv0.inc, %loop ], [ 0, %loop.ph ] "
176 " %iv1 = phi i32 [ %iv1.inc, %loop ], [ -2147483648, %loop.ph ] "
177 " %iv0.inc = add i32 %iv0, 1 "
178 " %iv1.inc = add i32 %iv1, 3 "
179 " %cond.cont = icmp slt i32 %iv0, %n "
180 " br i1 %cond.cont, label %loop, label %for.end.loopexit "
181 " "
182 "for.end.loopexit: "
183 " %val.ret = phi i32 [ %iv0, %loop ] "
184 " br label %detached.return "
185 " "
186 "detached.return: "
187 " ret i32 %val.ret "
188 "} ",
189 Err, C);
190
191 Function *F = M->getFunction("f_lcssa");
192 DivergenceAnalysis DA = buildDA(*F, true);
193 EXPECT_FALSE(DA.hasDetectedDivergence());
194
195 auto ItArg = F->arg_begin();
196 ItArg++;
197 auto &NArg = *ItArg;
198
199 // Seed divergence in argument %n
200 DA.markDivergent(NArg);
201
202 DA.compute();
203 EXPECT_TRUE(DA.hasDetectedDivergence());
204
205 // Verify that "ret %iv.0" is divergent
206 auto ItBlock = F->begin();
207 std::advance(ItBlock, 4);
208 auto &ExitBlock = *GetBlockByName("detached.return", *F);
209 auto &RetInst = *cast<ReturnInst>(ExitBlock.begin());
210 EXPECT_TRUE(DA.isDivergent(RetInst));
211}
212
213TEST_F(DivergenceAnalysisTest, DAJoinDivergence) {
214 LLVMContext C;
215 SMDiagnostic Err;
216
217 std::unique_ptr<Module> M = parseAssemblyString(
218 "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" "
219 " "
220 "define void @f_1(i1 %a, i1 %b, i1 %c) "
221 " local_unnamed_addr { "
222 "A: "
223 " br i1 %a, label %B, label %C "
224 " "
225 "B: "
226 " br i1 %b, label %C, label %D "
227 " "
228 "C: "
229 " %c.join = phi i32 [ 0, %A ], [ 1, %B ] "
230 " br i1 %c, label %D, label %E "
231 " "
232 "D: "
233 " %d.join = phi i32 [ 0, %B ], [ 1, %C ] "
234 " br label %E "
235 " "
236 "E: "
237 " %e.join = phi i32 [ 0, %C ], [ 1, %D ] "
238 " ret void "
239 "} "
240 " "
241 "define void @f_2(i1 %a, i1 %b, i1 %c) "
242 " local_unnamed_addr { "
243 "A: "
244 " br i1 %a, label %B, label %E "
245 " "
246 "B: "
247 " br i1 %b, label %C, label %D "
248 " "
249 "C: "
250 " br label %D "
251 " "
252 "D: "
253 " %d.join = phi i32 [ 0, %B ], [ 1, %C ] "
254 " br label %E "
255 " "
256 "E: "
257 " %e.join = phi i32 [ 0, %A ], [ 1, %D ] "
258 " ret void "
259 "} "
260 " "
261 "define void @f_3(i1 %a, i1 %b, i1 %c)"
262 " local_unnamed_addr { "
263 "A: "
264 " br i1 %a, label %B, label %C "
265 " "
266 "B: "
267 " br label %C "
268 " "
269 "C: "
270 " %c.join = phi i32 [ 0, %A ], [ 1, %B ] "
271 " br i1 %c, label %D, label %E "
272 " "
273 "D: "
274 " br label %E "
275 " "
276 "E: "
277 " %e.join = phi i32 [ 0, %C ], [ 1, %D ] "
278 " ret void "
279 "} ",
280 Err, C);
281
282 // Maps divergent conditions to the basic blocks whose Phi nodes become
283 // divergent. Blocks need to be listed in IR order.
284 using SmallBlockVec = SmallVector<const BasicBlock *, 4>;
285 using InducedDivJoinMap = std::map<const Value *, SmallBlockVec>;
286
287 // Actual function performing the checks.
288 auto CheckDivergenceFunc = [this](Function &F,
289 InducedDivJoinMap &ExpectedDivJoins) {
290 for (auto &ItCase : ExpectedDivJoins) {
291 auto *DivVal = ItCase.first;
292 auto DA = buildDA(F, false);
293 DA.markDivergent(*DivVal);
294 DA.compute();
295
296 // List of basic blocks that shall host divergent Phi nodes.
297 auto ItDivJoins = ItCase.second.begin();
298
299 for (auto &BB : F) {
300 auto *Phi = dyn_cast<PHINode>(BB.begin());
301 if (!Phi)
302 continue;
303
Nicolai Haehnle7052cb32018-10-18 12:54:39 +0000304 if (ItDivJoins != ItCase.second.end() && &BB == *ItDivJoins) {
Nicolai Haehnle59041682018-10-18 09:38:44 +0000305 EXPECT_TRUE(DA.isDivergent(*Phi));
306 // Advance to next block with expected divergent PHI node.
307 ++ItDivJoins;
308 } else {
309 EXPECT_FALSE(DA.isDivergent(*Phi));
310 }
311 }
312 }
313 };
314
315 {
316 auto *F = M->getFunction("f_1");
317 auto ItBlocks = F->begin();
318 ItBlocks++; // Skip A
319 ItBlocks++; // Skip B
320 auto *C = &*ItBlocks++;
321 auto *D = &*ItBlocks++;
322 auto *E = &*ItBlocks;
323
324 auto ItArg = F->arg_begin();
325 auto *AArg = &*ItArg++;
326 auto *BArg = &*ItArg++;
327 auto *CArg = &*ItArg;
328
329 InducedDivJoinMap DivJoins;
330 DivJoins.emplace(AArg, SmallBlockVec({C, D, E}));
331 DivJoins.emplace(BArg, SmallBlockVec({D, E}));
332 DivJoins.emplace(CArg, SmallBlockVec({E}));
333
334 CheckDivergenceFunc(*F, DivJoins);
335 }
336
337 {
338 auto *F = M->getFunction("f_2");
339 auto ItBlocks = F->begin();
340 ItBlocks++; // Skip A
341 ItBlocks++; // Skip B
342 ItBlocks++; // Skip C
343 auto *D = &*ItBlocks++;
344 auto *E = &*ItBlocks;
345
346 auto ItArg = F->arg_begin();
347 auto *AArg = &*ItArg++;
348 auto *BArg = &*ItArg++;
349 auto *CArg = &*ItArg;
350
351 InducedDivJoinMap DivJoins;
352 DivJoins.emplace(AArg, SmallBlockVec({E}));
353 DivJoins.emplace(BArg, SmallBlockVec({D}));
354 DivJoins.emplace(CArg, SmallBlockVec({}));
355
356 CheckDivergenceFunc(*F, DivJoins);
357 }
358
359 {
360 auto *F = M->getFunction("f_3");
361 auto ItBlocks = F->begin();
362 ItBlocks++; // Skip A
363 ItBlocks++; // Skip B
364 auto *C = &*ItBlocks++;
365 ItBlocks++; // Skip D
366 auto *E = &*ItBlocks;
367
368 auto ItArg = F->arg_begin();
369 auto *AArg = &*ItArg++;
370 auto *BArg = &*ItArg++;
371 auto *CArg = &*ItArg;
372
373 InducedDivJoinMap DivJoins;
374 DivJoins.emplace(AArg, SmallBlockVec({C}));
375 DivJoins.emplace(BArg, SmallBlockVec({}));
376 DivJoins.emplace(CArg, SmallBlockVec({E}));
377
378 CheckDivergenceFunc(*F, DivJoins);
379 }
380}
381
382TEST_F(DivergenceAnalysisTest, DASwitchUnreachableDefault) {
383 LLVMContext C;
384 SMDiagnostic Err;
385
386 std::unique_ptr<Module> M = parseAssemblyString(
387 "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" "
388 " "
389 "define void @switch_unreachable_default(i32 %cond) local_unnamed_addr { "
390 "entry: "
391 " switch i32 %cond, label %sw.default [ "
392 " i32 0, label %sw.bb0 "
393 " i32 1, label %sw.bb1 "
394 " ] "
395 " "
396 "sw.bb0: "
397 " br label %sw.epilog "
398 " "
399 "sw.bb1: "
400 " br label %sw.epilog "
401 " "
402 "sw.default: "
403 " unreachable "
404 " "
405 "sw.epilog: "
406 " %div.dbl = phi double [ 0.0, %sw.bb0], [ -1.0, %sw.bb1 ] "
407 " ret void "
408 "}",
409 Err, C);
410
411 auto *F = M->getFunction("switch_unreachable_default");
412 auto &CondArg = *F->arg_begin();
413 auto DA = buildDA(*F, false);
414
415 EXPECT_FALSE(DA.hasDetectedDivergence());
416
417 DA.markDivergent(CondArg);
418 DA.compute();
419
420 // Still %CondArg is divergent.
421 EXPECT_TRUE(DA.hasDetectedDivergence());
422
423 // The join uni.dbl is not divergent (see D52221)
424 auto &ExitBlock = *GetBlockByName("sw.epilog", *F);
425 auto &DivDblPhi = *cast<PHINode>(ExitBlock.begin());
426 EXPECT_TRUE(DA.isDivergent(DivDblPhi));
427}
428
429} // end anonymous namespace
430} // end namespace llvm