//===- PatternMatchTest.cpp -----------------------------------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
9
10#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
11#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
12#include "llvm/CodeGen/GlobalISel/Utils.h"
13#include "llvm/CodeGen/MIRParser/MIRParser.h"
14#include "llvm/CodeGen/MachineFunction.h"
15#include "llvm/CodeGen/MachineModuleInfo.h"
16#include "llvm/CodeGen/TargetFrameLowering.h"
17#include "llvm/CodeGen/TargetInstrInfo.h"
18#include "llvm/CodeGen/TargetLowering.h"
19#include "llvm/CodeGen/TargetSubtargetInfo.h"
20#include "llvm/Support/SourceMgr.h"
21#include "llvm/Support/TargetRegistry.h"
22#include "llvm/Support/TargetSelect.h"
23#include "llvm/Target/TargetMachine.h"
24#include "llvm/Target/TargetOptions.h"
25#include "gtest/gtest.h"
26
27using namespace llvm;
28using namespace MIPatternMatch;
29
30namespace {
31
32void initLLVM() {
33 InitializeAllTargets();
34 InitializeAllTargetMCs();
35 InitializeAllAsmPrinters();
36 InitializeAllAsmParsers();
37
38 PassRegistry *Registry = PassRegistry::getPassRegistry();
39 initializeCore(*Registry);
40 initializeCodeGen(*Registry);
41}
42
43/// Create a TargetMachine. As we lack a dedicated always available target for
44/// unittests, we go for "AArch64".
45std::unique_ptr<TargetMachine> createTargetMachine() {
46 Triple TargetTriple("aarch64--");
47 std::string Error;
48 const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
49 if (!T)
50 return nullptr;
51
52 TargetOptions Options;
53 return std::unique_ptr<TargetMachine>(T->createTargetMachine(
54 "AArch64", "", "", Options, None, None, CodeGenOpt::Aggressive));
55}
56
57std::unique_ptr<Module> parseMIR(LLVMContext &Context,
58 std::unique_ptr<MIRParser> &MIR,
59 const TargetMachine &TM, StringRef MIRCode,
60 const char *FuncName, MachineModuleInfo &MMI) {
61 SMDiagnostic Diagnostic;
62 std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode);
63 MIR = createMIRParser(std::move(MBuffer), Context);
64 if (!MIR)
65 return nullptr;
66
67 std::unique_ptr<Module> M = MIR->parseIRModule();
68 if (!M)
69 return nullptr;
70
71 M->setDataLayout(TM.createDataLayout());
72
73 if (MIR->parseMachineFunctions(*M, MMI))
74 return nullptr;
75
76 return M;
77}
78
/// Build a module and its MachineModuleInfo from a fixed MIR skeleton.
///
/// The MIR text is the raw literal below (a function named "func" with four
/// generic virtual registers and three COPYs from $x0, $x1 and $x2), followed
/// by \p MIRFunc and a closing "...\n". The text is handed to parseMIR().
///
/// \returns a pair of whatever parseMIR() produced (which may be a null
/// module) and the MachineModuleInfo created here for \p TM.
79std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>>
80createDummyModule(LLVMContext &Context, const TargetMachine &TM,
81 StringRef MIRFunc) {
// Backing storage for the concatenated, null-terminated MIR text.
82 SmallString<512> S;
83 StringRef MIRString = (Twine(R"MIR(
84---
85...
86name: func
87registers:
88 - { id: 0, class: _ }
89 - { id: 1, class: _ }
90 - { id: 2, class: _ }
91 - { id: 3, class: _ }
92body: |
93 bb.1:
Puyan Lotfi43e94b12018-01-31 22:04:26 +000094 %0(s64) = COPY $x0
95 %1(s64) = COPY $x1
96 %2(s64) = COPY $x2
Aditya Nandakumar2036f442018-01-25 02:53:06 +000097)MIR") + Twine(MIRFunc) + Twine("...\n"))
98 .toNullTerminatedStringRef(S);
// parseMIR() fills in MIR; the parser is a local and is discarded when this
// function returns.
99 std::unique_ptr<MIRParser> MIR;
100 auto MMI = make_unique<MachineModuleInfo>(&TM);
101 std::unique_ptr<Module> M =
102 parseMIR(Context, MIR, TM, MIRString, "func", *MMI);
103 return make_pair(std::move(M), std::move(MMI));
104}
105
106static MachineFunction *getMFFromMMI(const Module *M,
107 const MachineModuleInfo *MMI) {
108 Function *F = M->getFunction("func");
109 auto *MF = MMI->getMachineFunction(*F);
110 return MF;
111}
112
113static void collectCopies(SmallVectorImpl<unsigned> &Copies,
114 MachineFunction *MF) {
115 for (auto &MBB : *MF)
116 for (MachineInstr &MI : MBB) {
117 if (MI.getOpcode() == TargetOpcode::COPY)
118 Copies.push_back(MI.getOperand(0).getReg());
119 }
120}
121
122TEST(PatternMatchInstr, MatchIntConstant) {
123 LLVMContext Context;
124 std::unique_ptr<TargetMachine> TM = createTargetMachine();
125 if (!TM)
126 return;
127 auto ModuleMMIPair = createDummyModule(Context, *TM, "");
128 MachineFunction *MF =
129 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
130 SmallVector<unsigned, 4> Copies;
131 collectCopies(Copies, MF);
132 MachineBasicBlock *EntryMBB = &*MF->begin();
133 MachineIRBuilder B(*MF);
134 MachineRegisterInfo &MRI = MF->getRegInfo();
135 B.setInsertPt(*EntryMBB, EntryMBB->end());
136 auto MIBCst = B.buildConstant(LLT::scalar(64), 42);
137 uint64_t Cst;
138 bool match = mi_match(MIBCst->getOperand(0).getReg(), MRI, m_ICst(Cst));
139 ASSERT_TRUE(match);
140 ASSERT_EQ(Cst, (uint64_t)42);
141}
142
143TEST(PatternMatchInstr, MatchBinaryOp) {
144 LLVMContext Context;
145 std::unique_ptr<TargetMachine> TM = createTargetMachine();
146 if (!TM)
147 return;
148 auto ModuleMMIPair = createDummyModule(Context, *TM, "");
149 MachineFunction *MF =
150 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
151 SmallVector<unsigned, 4> Copies;
152 collectCopies(Copies, MF);
153 MachineBasicBlock *EntryMBB = &*MF->begin();
154 MachineIRBuilder B(*MF);
155 MachineRegisterInfo &MRI = MF->getRegInfo();
156 B.setInsertPt(*EntryMBB, EntryMBB->end());
157 LLT s64 = LLT::scalar(64);
158 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]);
159 // Test case for no bind.
160 bool match =
161 mi_match(MIBAdd->getOperand(0).getReg(), MRI, m_GAdd(m_Reg(), m_Reg()));
162 ASSERT_TRUE(match);
163 unsigned Src0, Src1, Src2;
164 match = mi_match(MIBAdd->getOperand(0).getReg(), MRI,
165 m_GAdd(m_Reg(Src0), m_Reg(Src1)));
166 ASSERT_TRUE(match);
167 ASSERT_EQ(Src0, Copies[0]);
168 ASSERT_EQ(Src1, Copies[1]);
169
170 // Build MUL(ADD %0, %1), %2
171 auto MIBMul = B.buildMul(s64, MIBAdd, Copies[2]);
172
173 // Try to match MUL.
174 match = mi_match(MIBMul->getOperand(0).getReg(), MRI,
175 m_GMul(m_Reg(Src0), m_Reg(Src1)));
176 ASSERT_TRUE(match);
177 ASSERT_EQ(Src0, MIBAdd->getOperand(0).getReg());
178 ASSERT_EQ(Src1, Copies[2]);
179
180 // Try to match MUL(ADD)
181 match = mi_match(MIBMul->getOperand(0).getReg(), MRI,
182 m_GMul(m_GAdd(m_Reg(Src0), m_Reg(Src1)), m_Reg(Src2)));
183 ASSERT_TRUE(match);
184 ASSERT_EQ(Src0, Copies[0]);
185 ASSERT_EQ(Src1, Copies[1]);
186 ASSERT_EQ(Src2, Copies[2]);
187
188 // Test Commutativity.
189 auto MIBMul2 = B.buildMul(s64, Copies[0], B.buildConstant(s64, 42));
190 // Try to match MUL(Cst, Reg) on src of MUL(Reg, Cst) to validate
191 // commutativity.
192 uint64_t Cst;
193 match = mi_match(MIBMul2->getOperand(0).getReg(), MRI,
194 m_GMul(m_ICst(Cst), m_Reg(Src0)));
195 ASSERT_TRUE(match);
196 ASSERT_EQ(Cst, (uint64_t)42);
197 ASSERT_EQ(Src0, Copies[0]);
198
199 // Make sure commutative doesn't work with something like SUB.
200 auto MIBSub = B.buildSub(s64, Copies[0], B.buildConstant(s64, 42));
201 match = mi_match(MIBSub->getOperand(0).getReg(), MRI,
202 m_GSub(m_ICst(Cst), m_Reg(Src0)));
203 ASSERT_FALSE(match);
204}
205
206TEST(PatternMatchInstr, MatchExtendsTrunc) {
207 LLVMContext Context;
208 std::unique_ptr<TargetMachine> TM = createTargetMachine();
209 if (!TM)
210 return;
211 auto ModuleMMIPair = createDummyModule(Context, *TM, "");
212 MachineFunction *MF =
213 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
214 SmallVector<unsigned, 4> Copies;
215 collectCopies(Copies, MF);
216 MachineBasicBlock *EntryMBB = &*MF->begin();
217 MachineIRBuilder B(*MF);
218 MachineRegisterInfo &MRI = MF->getRegInfo();
219 B.setInsertPt(*EntryMBB, EntryMBB->end());
220 LLT s64 = LLT::scalar(64);
221 LLT s32 = LLT::scalar(32);
222
223 auto MIBTrunc = B.buildTrunc(s32, Copies[0]);
224 auto MIBAExt = B.buildAnyExt(s64, MIBTrunc);
225 auto MIBZExt = B.buildZExt(s64, MIBTrunc);
226 auto MIBSExt = B.buildSExt(s64, MIBTrunc);
227 unsigned Src0;
228 bool match =
229 mi_match(MIBTrunc->getOperand(0).getReg(), MRI, m_GTrunc(m_Reg(Src0)));
230 ASSERT_TRUE(match);
231 ASSERT_EQ(Src0, Copies[0]);
232 match =
233 mi_match(MIBAExt->getOperand(0).getReg(), MRI, m_GAnyExt(m_Reg(Src0)));
234 ASSERT_TRUE(match);
235 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg());
236
237 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI, m_GSExt(m_Reg(Src0)));
238 ASSERT_TRUE(match);
239 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg());
240
241 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI, m_GZExt(m_Reg(Src0)));
242 ASSERT_TRUE(match);
243 ASSERT_EQ(Src0, MIBTrunc->getOperand(0).getReg());
244
245 // Match ext(trunc src)
246 match = mi_match(MIBAExt->getOperand(0).getReg(), MRI,
247 m_GAnyExt(m_GTrunc(m_Reg(Src0))));
248 ASSERT_TRUE(match);
249 ASSERT_EQ(Src0, Copies[0]);
250
251 match = mi_match(MIBSExt->getOperand(0).getReg(), MRI,
252 m_GSExt(m_GTrunc(m_Reg(Src0))));
253 ASSERT_TRUE(match);
254 ASSERT_EQ(Src0, Copies[0]);
255
256 match = mi_match(MIBZExt->getOperand(0).getReg(), MRI,
257 m_GZExt(m_GTrunc(m_Reg(Src0))));
258 ASSERT_TRUE(match);
259 ASSERT_EQ(Src0, Copies[0]);
260}
261
262TEST(PatternMatchInstr, MatchSpecificType) {
263 LLVMContext Context;
264 std::unique_ptr<TargetMachine> TM = createTargetMachine();
265 if (!TM)
266 return;
267 auto ModuleMMIPair = createDummyModule(Context, *TM, "");
268 MachineFunction *MF =
269 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
270 SmallVector<unsigned, 4> Copies;
271 collectCopies(Copies, MF);
272 MachineBasicBlock *EntryMBB = &*MF->begin();
273 MachineIRBuilder B(*MF);
274 MachineRegisterInfo &MRI = MF->getRegInfo();
275 B.setInsertPt(*EntryMBB, EntryMBB->end());
276 LLT s64 = LLT::scalar(64);
277 LLT s32 = LLT::scalar(32);
278 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]);
279
280 // Try to match a 64bit add.
281 ASSERT_FALSE(mi_match(MIBAdd->getOperand(0).getReg(), MRI,
282 m_GAdd(m_SpecificType(s32), m_Reg())));
283 ASSERT_TRUE(mi_match(MIBAdd->getOperand(0).getReg(), MRI,
284 m_GAdd(m_SpecificType(s64), m_Reg())));
285}
286
287TEST(PatternMatchInstr, MatchCombinators) {
288 LLVMContext Context;
289 std::unique_ptr<TargetMachine> TM = createTargetMachine();
290 if (!TM)
291 return;
292 auto ModuleMMIPair = createDummyModule(Context, *TM, "");
293 MachineFunction *MF =
294 getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
295 SmallVector<unsigned, 4> Copies;
296 collectCopies(Copies, MF);
297 MachineBasicBlock *EntryMBB = &*MF->begin();
298 MachineIRBuilder B(*MF);
299 MachineRegisterInfo &MRI = MF->getRegInfo();
300 B.setInsertPt(*EntryMBB, EntryMBB->end());
301 LLT s64 = LLT::scalar(64);
302 LLT s32 = LLT::scalar(32);
303 auto MIBAdd = B.buildAdd(s64, Copies[0], Copies[1]);
304 unsigned Src0, Src1;
305 bool match =
306 mi_match(MIBAdd->getOperand(0).getReg(), MRI,
307 m_all_of(m_SpecificType(s64), m_GAdd(m_Reg(Src0), m_Reg(Src1))));
308 ASSERT_TRUE(match);
309 ASSERT_EQ(Src0, Copies[0]);
310 ASSERT_EQ(Src1, Copies[1]);
311 // Check for s32 (which should fail).
312 match =
313 mi_match(MIBAdd->getOperand(0).getReg(), MRI,
314 m_all_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1))));
315 ASSERT_FALSE(match);
316 match =
317 mi_match(MIBAdd->getOperand(0).getReg(), MRI,
318 m_any_of(m_SpecificType(s32), m_GAdd(m_Reg(Src0), m_Reg(Src1))));
319 ASSERT_TRUE(match);
320 ASSERT_EQ(Src0, Copies[0]);
321 ASSERT_EQ(Src1, Copies[1]);
322}
323} // namespace
324
325int main(int argc, char **argv) {
326 ::testing::InitGoogleTest(&argc, argv);
327 initLLVM();
328 return RUN_ALL_TESTS();
329}