//===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains switch inst lowering optimizations and utilities for
// codegen, so that it can be used for both SelectionDAG and GlobalISel.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/SwitchLoweringUtils.h"

using namespace llvm;
using namespace SwitchCG;

uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
                                     unsigned First, unsigned Last) {
  assert(Last >= First);
  const APInt &LowCase = Clusters[First].Low->getValue();
  const APInt &HighCase = Clusters[Last].High->getValue();
  assert(LowCase.getBitWidth() == HighCase.getBitWidth());

  // FIXME: A range of consecutive cases has 100% density, but only requires one
  // comparison to lower. We should discriminate against such consecutive ranges
  // in jump tables.

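  // The range is clamped so that callers' density computations (which
  // multiply the range by a percentage) cannot overflow uint64_t.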
  return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
}

uint64_t
SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
                               unsigned First, unsigned Last) {
  assert(Last >= First);
  assert(TotalCases[Last] >= TotalCases[First]);
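  // TotalCases is a vector of prefix sums, so the number of cases covered by
  // Clusters[First..Last] is a difference of two prefix sums.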
  uint64_t NumCases =
      TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
  return NumCases;
}

void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB) {
#ifndef NDEBUG
  // Clusters must be non-empty, sorted, and only contain Range clusters.
  assert(!Clusters.empty());
  for (CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range);
  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
    return;

  const int64_t N = Clusters.size();
  const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;

  if (N < 2 || N < MinJumpTableEntries)
    return;

  // TotalCases[i]: Total number of cases in Clusters[0..i].
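  // E.g. clusters [0-2], [5-5], [7-8] give TotalCases = {3, 4, 6}.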
  SmallVector<unsigned, 8> TotalCases(N);
  for (unsigned i = 0; i < N; ++i) {
    const APInt &Hi = Clusters[i].High->getValue();
    const APInt &Lo = Clusters[i].Low->getValue();
    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
    if (i != 0)
      TotalCases[i] += TotalCases[i - 1];
  }

  // Cheap case: the whole range may be suitable for a jump table.
  uint64_t Range = getJumpTableRange(Clusters, 0, N - 1);
  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
  assert(NumCases < UINT64_MAX / 100);
  assert(Range >= NumCases);
  if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
    CaseCluster JTCluster;
    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
      Clusters[0] = JTCluster;
      Clusters.resize(1);
      return;
    }
  }

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // Split Clusters into minimum number of dense partitions. The algorithm uses
  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
  // for the Case Statement'" (1994), but builds the MinPartitions array in
  // reverse order to make it easier to reconstruct the partitions in ascending
  // order. In the choice between two optimal partitionings, it picks the one
  // which yields more jump tables.
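  //
  // The recurrence is MinPartitions[i] = min over j >= i of
  // (1 + MinPartitions[j + 1]), where either j == i or Clusters[i..j] is
  // suitable for a jump table, and MinPartitions[N] is taken to be 0.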

  // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);
  // PartitionsScore[i] is used to break ties when choosing between two
  // partitionings resulting in the same number of partitions.
  SmallVector<unsigned, 8> PartitionsScore(N);
  // For PartitionsScore, a small number of comparisons is considered as good as
  // a jump table and a single comparison is considered better than a jump
  // table.
  enum PartitionScores : unsigned {
    NoTable = 0,
    Table = 1,
    FewCases = 1,
    SingleCase = 2
  };

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  PartitionsScore[N - 1] = PartitionScores::SingleCase;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; i--) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;
    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;

    // Search for a solution that results in fewer partitions.
    for (int64_t j = N - 1; j > i; j--) {
      // Try building a partition from Clusters[i..j].
      uint64_t Range = getJumpTableRange(Clusters, i, j);
      uint64_t NumCases = getJumpTableNumCases(TotalCases, i, j);
      assert(NumCases < UINT64_MAX / 100);
      assert(Range >= NumCases);
      if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
        int64_t NumEntries = j - i + 1;

        if (NumEntries == 1)
          Score += PartitionScores::SingleCase;
        else if (NumEntries <= SmallNumberOfEntries)
          Score += PartitionScores::FewCases;
        else if (NumEntries >= MinJumpTableEntries)
          Score += PartitionScores::Table;

        // If this leads to fewer partitions, or to the same number of
        // partitions with better score, it is a better partitioning.
        if (NumPartitions < MinPartitions[i] ||
            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
          MinPartitions[i] = NumPartitions;
          LastElement[i] = j;
          PartitionsScore[i] = Score;
        }
      }
    }
  }

  // Iterate over the partitions, replacing some with jump tables in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(Last >= First);
    assert(DstIndex <= First);
    unsigned NumClusters = Last - First + 1;

    CaseCluster JTCluster;
    if (NumClusters >= MinJumpTableEntries &&
        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
      Clusters[DstIndex++] = JTCluster;
    } else {
      for (unsigned I = First; I <= Last; ++I)
        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
                                              unsigned First, unsigned Last,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              CaseCluster &JTCluster) {
  assert(First <= Last);

  auto Prob = BranchProbability::getZero();
  unsigned NumCmps = 0;
  std::vector<MachineBasicBlock*> Table;
  DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;

  // Initialize probabilities in JTProbs.
  for (unsigned I = First; I <= Last; ++I)
    JTProbs[Clusters[I].MBB] = BranchProbability::getZero();

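  // Build Table so that Table[Value - Low] is the destination block for Value;
  // gaps between clusters map to DefaultMBB. For example, clusters [1-2] -> A
  // and [4-4] -> B yield Table = {A, A, DefaultMBB, B}.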
  for (unsigned I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Prob += Clusters[I].Prob;
    const APInt &Low = Clusters[I].Low->getValue();
    const APInt &High = Clusters[I].High->getValue();
    NumCmps += (Low == High) ? 1 : 2;
    if (I != First) {
      // Fill the gap between this and the previous cluster.
      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
      assert(PreviousHigh.slt(Low));
      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
      for (uint64_t J = 0; J < Gap; J++)
        Table.push_back(DefaultMBB);
    }
    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
    for (uint64_t J = 0; J < ClusterSize; ++J)
      Table.push_back(Clusters[I].MBB);
    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
  }

  unsigned NumDests = JTProbs.size();
  if (TLI->isSuitableForBitTests(NumDests, NumCmps,
                                 Clusters[First].Low->getValue(),
                                 Clusters[Last].High->getValue(), *DL)) {
    // Clusters[First..Last] should be lowered as bit tests instead.
    return false;
  }

  // Create the MBB that will load from and jump through the table.
  // Note: We create it here, but it's not inserted into the function yet.
  MachineFunction *CurMF = FuncInfo.MF;
  MachineBasicBlock *JumpTableMBB =
      CurMF->CreateMachineBasicBlock(SI->getParent());

  // Add successors. Note: use table order for determinism.
  SmallPtrSet<MachineBasicBlock *, 8> Done;
  for (MachineBasicBlock *Succ : Table) {
    if (Done.count(Succ))
      continue;
    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
    Done.insert(Succ);
  }
  JumpTableMBB->normalizeSuccProbs();

  unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
                     ->createJumpTableIndex(Table);

  // Set up the jump table info.
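  // The register (-1U) and default block (nullptr) are placeholders; they get
  // filled in later, when the jump table header is actually lowered.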
  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
  JumpTableHeader JTH(Clusters[First].Low->getValue(),
                      Clusters[Last].High->getValue(), SI->getCondition(),
                      nullptr, false);
  JTCases.emplace_back(std::move(JTH), std::move(JT));

  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                     JTCases.size() - 1, Prob);
  return true;
}

void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
                                                   const SwitchInst *SI) {
  // Partition Clusters into as few subsets as possible, where each subset has a
  // range that fits in a machine word and has <= 3 unique destinations.

#ifndef NDEBUG
  // Clusters must be sorted and contain Range or JumpTable clusters.
  assert(!Clusters.empty());
  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
  for (const CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  for (unsigned i = 1; i < Clusters.size(); ++i)
    assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // If the target does not have a legal shift left, do not emit bit tests at
  // all.
  EVT PTy = TLI->getPointerTy(*DL);
  if (!TLI->isOperationLegal(ISD::SHL, PTy))
    return;

  int BitWidth = PTy.getSizeInBits();
  const int64_t N = Clusters.size();

  // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);

  // FIXME: This might not be the best algorithm for finding bit test clusters.

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; --i) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;

    // Search for a solution that results in fewer partitions.
    // Note: the search is limited by BitWidth, reducing time complexity.
    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
      // Try building a partition from Clusters[i..j].

      // Check the range.
      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
                                Clusters[j].High->getValue(), *DL))
        continue;

      // Check the number of destinations and cluster types.
      // FIXME: This works, but doesn't seem very efficient.
      bool RangesOnly = true;
      BitVector Dests(FuncInfo.MF->getNumBlockIDs());
      for (int64_t k = i; k <= j; k++) {
        if (Clusters[k].Kind != CC_Range) {
          RangesOnly = false;
          break;
        }
        Dests.set(Clusters[k].MBB->getNumber());
      }
      if (!RangesOnly || Dests.count() > 3)
        break;

      // Check if it's a better partition.
      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
      if (NumPartitions < MinPartitions[i]) {
        // Found a better partition.
        MinPartitions[i] = NumPartitions;
        LastElement[i] = j;
      }
    }
  }

  // Iterate over the partitions, replacing with bit-test clusters in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(First <= Last);
    assert(DstIndex <= First);

    CaseCluster BitTestCluster;
    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
      Clusters[DstIndex++] = BitTestCluster;
    } else {
      size_t NumClusters = Last - First + 1;
      std::memmove(&Clusters[DstIndex], &Clusters[First],
                   sizeof(Clusters[0]) * NumClusters);
      DstIndex += NumClusters;
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
                                             unsigned First, unsigned Last,
                                             const SwitchInst *SI,
                                             CaseCluster &BTCluster) {
  assert(First <= Last);
  if (First == Last)
    return false;

  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  unsigned NumCmps = 0;
  for (int64_t I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Dests.set(Clusters[I].MBB->getNumber());
    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  }
  unsigned NumDests = Dests.count();

  APInt Low = Clusters[First].Low->getValue();
  APInt High = Clusters[Last].High->getValue();
  assert(Low.slt(High));

  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
    return false;

  APInt LowBound;
  APInt CmpRange;

  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
         "Case range must fit in bit mask!");

  // Check if the clusters cover a contiguous range such that no value in the
  // range will jump to the default statement.
  bool ContiguousRange = true;
  for (int64_t I = First + 1; I <= Last; ++I) {
    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
      ContiguousRange = false;
      break;
    }
  }

  if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
    // Optimize the case where all the case values fit in a word without having
    // to subtract minValue. In this case, we can optimize away the subtraction.
    LowBound = APInt::getNullValue(Low.getBitWidth());
    CmpRange = High;
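    // Without the subtraction, values below Low also reach the bit tests (and
    // match no case bit), so the range check cannot be omitted.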
    ContiguousRange = false;
  } else {
    LowBound = Low;
    CmpRange = High - Low;
  }

  CaseBitsVector CBV;
  auto TotalProb = BranchProbability::getZero();
  for (unsigned i = First; i <= Last; ++i) {
    // Find the CaseBits for this destination.
    unsigned j;
    for (j = 0; j < CBV.size(); ++j)
      if (CBV[j].BB == Clusters[i].MBB)
        break;
    if (j == CBV.size())
      CBV.push_back(
          CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
    CaseBits *CB = &CBV[j];

    // Update Mask, Bits and ExtraProb.
    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
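    // Set bits [Lo, Hi] in the mask: a run of Hi - Lo + 1 ones, shifted up by
    // Lo. E.g. Lo = 2, Hi = 4 gives 0b11100.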
    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
    CB->Bits += Hi - Lo + 1;
    CB->ExtraProb += Clusters[i].Prob;
    TotalProb += Clusters[i].Prob;
  }

  BitTestInfo BTI;
  llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
    // Sort by probability first, number of bits second, bit mask third.
    if (a.ExtraProb != b.ExtraProb)
      return a.ExtraProb > b.ExtraProb;
    if (a.Bits != b.Bits)
      return a.Bits > b.Bits;
    return a.Mask < b.Mask;
  });

  for (auto &CB : CBV) {
    MachineBasicBlock *BitTestBB =
        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
    BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
  }
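  // The register (-1U), register type (MVT::Other), and parent/default blocks
  // (nullptr) are placeholders filled in when the bit-test header is lowered.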
  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                            SI->getCondition(), -1U, MVT::Other, false,
                            ContiguousRange, nullptr, nullptr, std::move(BTI),
                            TotalProb);

  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                    BitTestCases.size() - 1, TotalProb);
  return true;
}

void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
#ifndef NDEBUG
  for (const CaseCluster &CC : Clusters)
    assert(CC.Low == CC.High && "Input clusters must be single-case");
#endif

  llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
    return a.Low->getValue().slt(b.Low->getValue());
  });

  // Merge adjacent clusters with the same destination.
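  // E.g. single-case clusters 1 -> A, 2 -> A, 3 -> B become [1-2] -> A,
  // [3-3] -> B.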
  const unsigned N = Clusters.size();
  unsigned DstIndex = 0;
  for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
    CaseCluster &CC = Clusters[SrcIndex];
    const ConstantInt *CaseVal = CC.Low;
    MachineBasicBlock *Succ = CC.MBB;

    if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
        (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
      // If this case has the same successor and is a neighbour, merge it into
      // the previous cluster.
      Clusters[DstIndex - 1].High = CaseVal;
      Clusters[DstIndex - 1].Prob += CC.Prob;
    } else {
      std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
                   sizeof(Clusters[SrcIndex]));
    }
  }
  Clusters.resize(DstIndex);
}