blob: 14e180bb29b99d868190b3410a230a460e88780b [file] [log] [blame]
Uday Bondhugulab553adb2018-08-25 17:17:56 -07001//===- HyperRectangularSet.cpp - MLIR HyperRectangularSet Class--*- C++ -*-===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
18// Structures for affine/polyhedral analysis of MLIR functions.
19//
20//===----------------------------------------------------------------------===//
21
22#include "mlir/Analysis/HyperRectangularSet.h"
23
24#include <algorithm>
25
26#include "mlir/IR/AffineExpr.h"
27#include "mlir/IR/IntegerSet.h"
28#include "llvm/Support/raw_ostream.h"
29
30using namespace mlir;
31
32// TODO(bondhugula): clean this code up.
33// Get the constant bound that is either the min or max (depending on 'cmp').
34static Optional<int64_t>
35getReducedConstBound(const HyperRectangularSet &set, unsigned *idx,
36 std::function<bool(int64_t, int64_t)> const &cmp) {
37 Optional<int64_t> val = None;
38
39 for (unsigned i = 0, n = set.getNumDims(); i < n; i++) {
40 auto &ubs = set.getLowerBound(i);
41 unsigned j = 0;
42 AffineBoundExprList::const_iterator it, e;
43 for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) {
44 if (auto *cExpr = dyn_cast<AffineConstantExpr>(*it)) {
45 if (val == None) {
46 val = cExpr->getValue();
47 *idx = j;
48 } else if (cmp(cExpr->getValue(), val.getValue())) {
49 val = cExpr->getValue();
50 *idx = j;
51 }
52 }
53 }
54 }
55 return val;
56}
57
58// Merge the two lists of AffineExpr's into a single one, avoiding duplicates.
59// lb specifies whether the bound lists are for a lower bound or an upper bound.
60// TODO(bondhugula): clean this code up.
61static void mergeBounds(const HyperRectangularSet &set,
62 AffineBoundExprList &lhsList,
63 const AffineBoundExprList &rhsList, bool lb) {
64 // The list of bounds is going to be small. Just a linear search
65 // should be enough to create a list without duplicates.
66 for (auto *expr : rhsList) {
67 AffineBoundExprList::const_iterator it;
68 for (it = lhsList.begin(); it != lhsList.end(); it++) {
69 if (expr == *it)
70 break;
71 }
72 if (it == lhsList.end()) {
73 // There can only be one constant affine expr in this bound list.
74 if (auto *cExpr = dyn_cast<AffineConstantExpr>(expr)) {
75 unsigned idx;
76 if (lb) {
77 auto cb = getReducedConstBound(
78 set, &idx,
79 [](int64_t newVal, int64_t oldVal) { return newVal < oldVal; });
80 if (!cb.hasValue()) {
81 lhsList.push_back(expr);
82 continue;
83 }
84 if (cExpr->getValue() < cb)
85 lhsList[idx] = expr;
86 // A constant value >= the existing bound constant.
87 continue;
88 }
89 // Upper bound case.
90 auto cb =
91 getReducedConstBound(set, &idx, [](int64_t newVal, int64_t oldVal) {
92 return newVal > oldVal;
93 });
94 if (!cb.hasValue()) {
95 lhsList.push_back(expr);
96 continue;
97 }
98 if (cExpr->getValue() > cb)
99 lhsList[idx] = expr;
100 continue;
101 }
102 // Not a constant expression; push it.
103 // TODO(bondhugula): check if this was implied by an existing symbolic
104 // expression or by the context.
105 lhsList.push_back(expr);
106 }
107 }
108}
109
110HyperRectangularSet::HyperRectangularSet(unsigned numDims, unsigned numSymbols,
111 ArrayRef<ArrayRef<AffineExpr *>> lbs,
112 ArrayRef<ArrayRef<AffineExpr *>> ubs,
Uday Bondhugula83a41c92018-08-30 17:35:15 -0700113 MLIRContext *context,
Uday Bondhugulab553adb2018-08-25 17:17:56 -0700114 IntegerSet *symbolContext)
Uday Bondhugula83a41c92018-08-30 17:35:15 -0700115 : context(symbolContext ? MutableIntegerSet(symbolContext, context)
116 : MutableIntegerSet(numDims, numSymbols, context)) {
Uday Bondhugulab553adb2018-08-25 17:17:56 -0700117 unsigned d = 0;
118 for (auto boundList : lbs) {
119 AffineBoundExprList lb;
120 for (auto *expr : boundList) {
121 assert(expr->isSymbolicOrConstant() &&
122 "bound expression should be symbolic or constant");
123 lb.push_back(expr);
124 }
125 mergeBounds(*this, lowerBounds[d++], lb, true);
126 }
127
128 d = 0;
129 for (auto boundList : ubs) {
130 AffineBoundExprList ub;
131 for (auto *expr : boundList) {
132 assert(expr->isSymbolicOrConstant() &&
133 "bound expression should be symbolic or constant");
134 ub.push_back(expr);
135 }
136 mergeBounds(*this, upperBounds[d++], ub, false);
137 }
138
139 simplifyUnderContext();
140}
141
142void HyperRectangularSet::projectOut(unsigned idx, unsigned num) {
143 // Erase the bounds along the projected out dimensions.
144 lowerBounds.erase(lowerBounds.begin() + idx, lowerBounds.begin() + idx + num);
145 upperBounds.erase(upperBounds.begin() + idx, upperBounds.begin() + idx + num);
146 numDims -= num;
147}
148
149void HyperRectangularSet::intersect(const HyperRectangularSet &rhs) {
150 assert(rhs.getNumSymbols() == getNumSymbols() &&
151 rhs.getNumDims() == getNumDims() && "operand space does not match");
152
153 // Intersection is just a concatenation of distinct bounds.
154 for (unsigned i = 0, n = getNumDims(); i < n; i++) {
155 mergeBounds(*this, getLowerBound(i), rhs.getLowerBound(i), true);
156 mergeBounds(*this, getUpperBound(i), rhs.getUpperBound(i), false);
157 }
158}
159
160void HyperRectangularSet::print(raw_ostream &os) const {
161 os << "Hyper rectangular set: " << numDims << "dimensions, " << numSymbols
162 << "symbols\n";
163 os << "Lower bounds\n";
164 unsigned d = 0;
165 for (auto &lb : lowerBounds) {
166 os << "Dim " << d++ << "\n";
167 for (auto *expr : lb) {
168 expr->print(os);
169 }
170 }
171 d = 0;
172 os << "Upper bounds\n";
173 for (auto &lb : upperBounds) {
174 os << "Dim " << d++ << "\n";
175 for (auto *expr : lb) {
176 expr->print(os);
177 }
178 }
179}
180
181void HyperRectangleList::projectOut(unsigned idx, unsigned num) {
182 for (auto &elt : hyperRectangles) {
183 elt.projectOut(idx, num);
184 }
185 // TODO: after a project out, some of the sets may be identical. Remove those.
186}
187
188bool HyperRectangleList::empty() const {
189 for (auto &set : hyperRectangles) {
190 if (!set.empty())
191 return false;
192 }
193 return true;
194}
195
196bool HyperRectangularSet::empty() const {
197 assert(0 && "unimplemented");
198 return false;
199}
200
201void HyperRectangularSet::dump() const { print(llvm::errs()); }