blob: dc41e5425d4ad58351c41cf0f50a1f69e72d1550 [file] [log] [blame]
Aditya Nandakumar2036f442018-01-25 02:53:06 +00001//===- PatternMatchTest.cpp -----------------------------------------------===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9
10#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
11#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
12#include "llvm/CodeGen/GlobalISel/Utils.h"
13#include "llvm/CodeGen/MIRParser/MIRParser.h"
14#include "llvm/CodeGen/MachineFunction.h"
15#include "llvm/CodeGen/MachineModuleInfo.h"
16#include "llvm/CodeGen/TargetFrameLowering.h"
17#include "llvm/CodeGen/TargetInstrInfo.h"
18#include "llvm/CodeGen/TargetLowering.h"
19#include "llvm/CodeGen/TargetSubtargetInfo.h"
20#include "llvm/Support/SourceMgr.h"
21#include "llvm/Support/TargetRegistry.h"
22#include "llvm/Support/TargetSelect.h"
23#include "llvm/Target/TargetMachine.h"
24#include "llvm/Target/TargetOptions.h"
25#include "gtest/gtest.h"
26
27using namespace llvm;
28using namespace MIPatternMatch;
29
30namespace {
31
32void initLLVM() {
33 InitializeAllTargets();
34 InitializeAllTargetMCs();
35 InitializeAllAsmPrinters();
36 InitializeAllAsmParsers();
37
38 PassRegistry *Registry = PassRegistry::getPassRegistry();
39 initializeCore(*Registry);
40 initializeCodeGen(*Registry);
41}
42
43/// Create a TargetMachine. As we lack a dedicated always available target for
44/// unittests, we go for "AArch64".
45std::unique_ptr<TargetMachine> createTargetMachine() {
46 Triple TargetTriple("aarch64--");
47 std::string Error;
48 const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
49 if (!T)
50 return nullptr;
51
52 TargetOptions Options;
53 return std::unique_ptr<TargetMachine>(T->createTargetMachine(
54 "AArch64", "", "", Options, None, None, CodeGenOpt::Aggressive));
55}
56
57std::unique_ptr<Module> parseMIR(LLVMContext &Context,
58 std::unique_ptr<MIRParser> &MIR,
59 const TargetMachine &TM, StringRef MIRCode,
60 const char *FuncName, MachineModuleInfo &MMI) {
61 SMDiagnostic Diagnostic;
62 std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode);
63 MIR = createMIRParser(std::move(MBuffer), Context);
64 if (!MIR)
65 return nullptr;
66
67 std::unique_ptr<Module> M = MIR->parseIRModule();
68 if (!M)
69 return nullptr;
70
71 M->setDataLayout(TM.createDataLayout());
72
73 if (MIR->parseMachineFunctions(*M, MMI))
74 return nullptr;
75
76 return M;
77}
78
79std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>>
80createDummyModule(LLVMContext &Context, const TargetMachine &TM,
81 StringRef MIRFunc) {
82 SmallString<512> S;
83 StringRef MIRString = (Twine(R"MIR(
84---
85...
86name: func
87registers:
88 - { id: 0, class: _ }
89 - { id: 1, class: _ }
90 - { id: 2, class: _ }
91 - { id: 3, class: _ }
92body: |
93 bb.1:
Puyan Lotfi43e94b12018-01-31 22:04:26 +000094 %0(s64) = COPY $x0
95 %1(s64) = COPY $x1
96 %2(s64) = COPY $x2
Aditya Nandakumar2036f442018-01-25 02:53:06 +000097)MIR") + Twine(MIRFunc) + Twine("...\n"))
98 .toNullTerminatedStringRef(S);
99 std::unique_ptr<MIRParser> MIR;
100 auto MMI = make_unique<MachineModuleInfo>(&TM);
101 std::unique_ptr<Module> M =
102 parseMIR(Context, MIR, TM, MIRString, "func", *MMI);
103 return make_pair(std::move(M), std::move(MMI));
104}
105
106static MachineFunction *getMFFromMMI(const Module *M,
107 const MachineModuleInfo *MMI) {
108 Function *F = M->getFunction("func");
109 auto *MF = MMI->getMachineFunction(*F);
110 return MF;
111}
112
113static void collectCopies(SmallVectorImpl<unsigned> &Copies,
114 MachineFunction *MF) {
115 for (auto &MBB : *MF)
116 for (MachineInstr &MI : MBB) {
117 if (MI.getOpcode() == TargetOpcode::COPY)
118 Copies.push_back(MI.getOperand(0).getReg());
119 }
120}
121
122TEST(PatternMatchInstr, MatchIntConstant) {
123 LLVMContext Context;
124 std::unique_ptr<TargetMachine> TM = createTargetMachine();
125 if (!TM)
126 return;
127 auto ModuleMMIPair = createDummyModule(Context, *TM, "");
128 MachineFunction *MF =
129 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
130 SmallVector<unsigned, 4> Copies;
131 collectCopies(Copies, MF);
132 MachineBasicBlock *EntryMBB = &*MF->begin();
133 MachineIRBuilder B(*MF);
134 MachineRegisterInfo &MRI = MF->getRegInfo();
135 B.setInsertPt(*EntryMBB, EntryMBB->end());
136 auto MIBCst = B.buildConstant(LLT::scalar(64), 42);
137 uint64_t Cst;
138 bool match = mi_match(MIBCst->getOperand(0).getReg(), MRI, m_ICst(Cst));
139 ASSERT_TRUE(match);
140 ASSERT_EQ(Cst, (uint64_t)42);
141}
142
143TEST(PatternMatchInstr, MatchBinaryOp) {
144 LLVMContext Context;
145 std::unique_ptr<TargetMachine> TM = createTargetMachine();
146 if (!TM)
147 return;
148 auto ModuleMMIPair = createDummyModule(Context, *TM, "");
149 MachineFunction *MF =
150 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
151 SmallVector<unsigned, 4> Copies;
152 collectCopies(Copies, MF);
153 MachineBasicBlock *EntryMBB = &*MF->begin();
154 MachineIRBuilder B(*MF);
155 MachineRegisterInfo &MRI = MF->getRegInfo();
156 B.setInsertPt(*EntryMBB, EntryMBB->end());
157 LLT s64 = LLT::scalar(64);
158 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]);
159 // Test case for no bind.
160 bool match =
161 mi_match(MIBAdd->getOperand(0).getReg(), MRI, m_GAdd(m_Reg(), m_Reg()));
162 ASSERT_TRUE(match);
163 unsigned Src0, Src1, Src2;
164 match = mi_match(MIBAdd->getOperand(0).getReg(), MRI,
165 m_GAdd(m_Reg(Src0), m_Reg(Src1)));
166 ASSERT_TRUE(match);
167 ASSERT_EQ(Src0, Copies[0]);
168 ASSERT_EQ(Src1, Copies[1]);
169
170 // Build MUL(ADD %0, %1), %2
171 auto MIBMul = B.buildMul(s64, MIBAdd, Copies[2]);
172
173 // Try to match MUL.
174 match = mi_match(MIBMul->getOperand(0).getReg(), MRI,
175 m_GMul(m_Reg(Src0), m_Reg(Src1)));
176 ASSERT_TRUE(match);
177 ASSERT_EQ(Src0, MIBAdd->getOperand(0).getReg());
178 ASSERT_EQ(Src1, Copies[2]);
179
180 // Try to match MUL(ADD)
181 match = mi_match(MIBMul->getOperand(0).getReg(), MRI,
182 m_GMul(m_GAdd(m_Reg(Src0), m_Reg(Src1)), m_Reg(Src2)));
183 ASSERT_TRUE(match);
184 ASSERT_EQ(Src0, Copies[0]);
185 ASSERT_EQ(Src1, Copies[1]);
186 ASSERT_EQ(Src2, Copies[2]);
187
188 // Test Commutativity.
189 auto MIBMul2 = B.buildMul(s64, Copies[0], B.buildConstant(s64, 42));
190 // Try to match MUL(Cst, Reg) on src of MUL(Reg, Cst) to validate
191 // commutativity.
192 uint64_t Cst;
193 match = mi_match(MIBMul2->getOperand(0).getReg(), MRI,
194 m_GMul(m_ICst(Cst), m_Reg(Src0)));
195 ASSERT_TRUE(match);
196 ASSERT_EQ(Cst, (uint64_t)42);
197 ASSERT_EQ(Src0, Copies[0]);
198
199 // Make sure commutative doesn't work with something like SUB.
200 auto MIBSub = B.buildSub(s64, Copies[0], B.buildConstant(s64, 42));
201 match = mi_match(MIBSub->getOperand(0).getReg(), MRI,
202 m_GSub(m_ICst(Cst), m_Reg(Src0)));
203 ASSERT_FALSE(match);
Aditya Nandakumar6250b182018-02-13 20:09:13 +0000204
205 auto MIBFMul = B.buildInstr(TargetOpcode::G_FMUL, s64, Copies[0],
206 B.buildConstant(s64, 42));
207 // Match and test commutativity for FMUL.
208 match = mi_match(MIBFMul->getOperand(0).getReg(), MRI,
209 m_GFMul(m_ICst(Cst), m_Reg(Src0)));
210 ASSERT_TRUE(match);
211 ASSERT_EQ(Cst, (uint64_t)42);
212 ASSERT_EQ(Src0, Copies[0]);
Aditya Nandakumar2036f442018-01-25 02:53:06 +0000213}
214
215TEST(PatternMatchInstr, MatchExtendsTrunc) {
216 LLVMContext Context;
217 std::unique_ptr<TargetMachine> TM = createTargetMachine();
218 if (!TM)
219 return;
220 auto ModuleMMIPair = createDummyModule(Context, *TM, "");
221 MachineFunction *MF =
222 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
223 SmallVector<unsigned, 4> Copies;
224 collectCopies(Copies, MF);
225 MachineBasicBlock *EntryMBB = &*MF->begin();
226 MachineIRBuilder B(*MF);
227 MachineRegisterInfo &MRI = MF->getRegInfo();
228 B.setInsertPt(*EntryMBB, EntryMBB->end());
229 LLT s64 = LLT::scalar(64);
230 LLT s32 = LLT::scalar(32);
231
232 auto MIBTrunc = B.buildTrunc(s32, Copies[0]);
233 auto MIBAExt = B.buildAnyExt(s64, MIBTrunc);
234 auto MIBZExt = B.buildZExt(s64, MIBTrunc);
235 auto MIBSExt = B.buildSExt(s64, MIBTrunc);
236 unsigned Src0;
237 bool match =
238 mi_match(MIBTrunc->getOperand(0).getReg(), MRI, m_GTrunc(m_Reg(Src0)));
239 ASSERT_TRUE(match);
240 ASSERT_EQ(Src0, Copies[0]);
241 match =
242 mi_match(MIBAExt->getOperand(0).getReg(), MRI, m_GAnyExt(m_Reg(Src0)));
243 ASSERT_TRUE(match);
244 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg());
245
246 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI, m_GSExt(m_Reg(Src0)));
247 ASSERT_TRUE(match);
248 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg());
249
250 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI, m_GZExt(m_Reg(Src0)));
251 ASSERT_TRUE(match);
252 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg());
253
254 // Match ext(trunc src)
255 match = mi_match(MIBAExt->getOperand(0).getReg(), MRI,
256 m_GAnyExt(m_GTrunc(m_Reg(Src0))));
257 ASSERT_TRUE(match);
258 ASSERT_EQ(Src0, Copies[0]);
259
260 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI,
261 m_GSExt(m_GTrunc(m_Reg(Src0))));
262 ASSERT_TRUE(match);
263 ASSERT_EQ(Src0, Copies[0]);
264
265 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI,
266 m_GZExt(m_GTrunc(m_Reg(Src0))));
267 ASSERT_TRUE(match);
268 ASSERT_EQ(Src0, Copies[0]);
269}
270
271TEST(PatternMatchInstr, MatchSpecificType) {
272 LLVMContext Context;
273 std::unique_ptr<TargetMachine> TM = createTargetMachine();
274 if (!TM)
275 return;
276 auto ModuleMMIPair = createDummyModule(Context, *TM, "");
277 MachineFunction *MF =
278 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
279 SmallVector<unsigned, 4> Copies;
280 collectCopies(Copies, MF);
281 MachineBasicBlock *EntryMBB = &*MF->begin();
282 MachineIRBuilder B(*MF);
283 MachineRegisterInfo &MRI = MF->getRegInfo();
284 B.setInsertPt(*EntryMBB, EntryMBB->end());
285 LLT s64 = LLT::scalar(64);
286 LLT s32 = LLT::scalar(32);
287 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]);
288
289 // Try to match a 64bit add.
290 ASSERT_FALSE(mi_match(MIBAdd->getOperand(0).getReg(), MRI,
291 m_GAdd(m_SpecificType(s32), m_Reg())));
292 ASSERT_TRUE(mi_match(MIBAdd->getOperand(0).getReg(), MRI,
293 m_GAdd(m_SpecificType(s64), m_Reg())));
294}
295
296TEST(PatternMatchInstr, MatchCombinators) {
297 LLVMContext Context;
298 std::unique_ptr<TargetMachine> TM = createTargetMachine();
299 if (!TM)
300 return;
301 auto ModuleMMIPair = createDummyModule(Context, *TM, "");
302 MachineFunction *MF =
303 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
304 SmallVector<unsigned, 4> Copies;
305 collectCopies(Copies, MF);
306 MachineBasicBlock *EntryMBB = &*MF->begin();
307 MachineIRBuilder B(*MF);
308 MachineRegisterInfo &MRI = MF->getRegInfo();
309 B.setInsertPt(*EntryMBB, EntryMBB->end());
310 LLT s64 = LLT::scalar(64);
311 LLT s32 = LLT::scalar(32);
312 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]);
313 unsigned Src0, Src1;
314 bool match =
315 mi_match(MIBAdd->getOperand(0).getReg(), MRI,
316 m_all_of(m_SpecificType(s64), m_GAdd(m_Reg(Src0), m_Reg(Src1))));
317 ASSERT_TRUE(match);
318 ASSERT_EQ(Src0, Copies[0]);
319 ASSERT_EQ(Src1, Copies[1]);
320 // Check for s32 (which should fail).
321 match =
322 mi_match(MIBAdd->getOperand(0).getReg(), MRI,
323 m_all_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1))));
324 ASSERT_FALSE(match);
325 match =
326 mi_match(MIBAdd->getOperand(0).getReg(), MRI,
327 m_any_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1))));
328 ASSERT_TRUE(match);
329 ASSERT_EQ(Src0, Copies[0]);
330 ASSERT_EQ(Src1, Copies[1]);
331}
332} // namespace
333
334int main(int argc, char **argv) {
335 ::testing::InitGoogleTest(&argc, argv);
336 initLLVM();
337 return RUN_ALL_TESTS();
338}