//===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Allocator.h"

using namespace mlir;
using namespace llvm;

namespace {
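// These DenseMapInfo-style key-info structs describe how to hash and compare
// each kind of uniqued object from its constituent data (the "key"), so that
// the DenseSets below can look up an existing object via insert_as without
// first constructing a temporary one.
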
struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType*> {
  // Function types are uniqued based on their inputs and results.
  using KeyTy = std::pair<ArrayRef<Type*>, ArrayRef<Type*>>;
  using DenseMapInfo<FunctionType*>::getHashValue;
  using DenseMapInfo<FunctionType*>::isEqual;

  static unsigned getHashValue(KeyTy key) {
    return hash_combine(hash_combine_range(key.first.begin(), key.first.end()),
                        hash_combine_range(key.second.begin(),
                                           key.second.end()));
  }

  static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
      return false;
    return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
  }
};

struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
  // Affine maps are uniqued based on their arguments and affine expressions.
  using KeyTy = std::pair<unsigned, unsigned>;
  using DenseMapInfo<AffineMap *>::getHashValue;
  using DenseMapInfo<AffineMap *>::isEqual;

  static unsigned getHashValue(KeyTy key) {
    // FIXME(bondhugula): placeholder for now
    return hash_combine(key.first, key.second);
  }

  static bool isEqual(const KeyTy &lhs, const AffineMap *rhs) {
    // TODO(bondhugula)
    return false;
  }
};

struct VectorTypeKeyInfo : DenseMapInfo<VectorType*> {
  // Vectors are uniqued based on their element type and shape.
  using KeyTy = std::pair<Type*, ArrayRef<unsigned>>;
  using DenseMapInfo<VectorType*>::getHashValue;
  using DenseMapInfo<VectorType*>::isEqual;

  static unsigned getHashValue(KeyTy key) {
    return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
                        hash_combine_range(key.second.begin(),
                                           key.second.end()));
  }

  static bool isEqual(const KeyTy &lhs, const VectorType *rhs) {
    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
      return false;
    return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
  }
};

struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType*> {
  // Ranked tensors are uniqued based on their element type and shape.
  using KeyTy = std::pair<Type*, ArrayRef<int>>;
  using DenseMapInfo<RankedTensorType*>::getHashValue;
  using DenseMapInfo<RankedTensorType*>::isEqual;

  static unsigned getHashValue(KeyTy key) {
    return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
                        hash_combine_range(key.second.begin(),
                                           key.second.end()));
  }

  static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
      return false;
    return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
  }
};
} // end anonymous namespace.

namespace mlir {
/// This is the implementation of the MLIRContext class, using the pImpl idiom.
/// This class is completely private to this file, so everything is public.
class MLIRContextImpl {
public:
  /// We put immortal objects into this allocator.
  llvm::BumpPtrAllocator allocator;

  /// These are identifiers uniqued into this MLIRContext.
  llvm::StringMap<char, llvm::BumpPtrAllocator&> identifiers;

  // Primitive type uniquing.
  PrimitiveType *primitives[int(TypeKind::LAST_PRIMITIVE_TYPE)+1] = {nullptr};

  // Affine map uniquing.
  using AffineMapSet = DenseSet<AffineMap *, AffineMapKeyInfo>;
  AffineMapSet affineMaps;

  /// Integer type uniquing.
  DenseMap<unsigned, IntegerType*> integers;

  /// Function type uniquing.
  using FunctionTypeSet = DenseSet<FunctionType*, FunctionTypeKeyInfo>;
  FunctionTypeSet functions;

  /// Vector type uniquing.
  using VectorTypeSet = DenseSet<VectorType*, VectorTypeKeyInfo>;
  VectorTypeSet vectors;

  /// Ranked tensor type uniquing.
  using RankedTensorTypeSet = DenseSet<RankedTensorType*,
                                       RankedTensorTypeKeyInfo>;
  RankedTensorTypeSet rankedTensors;

  /// Unranked tensor type uniquing.
  DenseMap<Type*, UnrankedTensorType*> unrankedTensors;

public:
  MLIRContextImpl() : identifiers(allocator) {}

  /// Copy the specified array of elements into memory managed by our bump
  /// pointer allocator. This assumes the elements are all PODs.
  template<typename T>
  ArrayRef<T> copyInto(ArrayRef<T> elements) {
    auto result = allocator.Allocate<T>(elements.size());
    std::uninitialized_copy(elements.begin(), elements.end(), result);
    return ArrayRef<T>(result, elements.size());
  }
};
} // end namespace mlir

MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
}

MLIRContext::~MLIRContext() {
}

//===----------------------------------------------------------------------===//
// Identifier
//===----------------------------------------------------------------------===//

/// Return an identifier for the specified string.
Identifier Identifier::get(StringRef str, const MLIRContext *context) {
  assert(!str.empty() && "Cannot create an empty identifier");
  assert(str.find('\0') == StringRef::npos &&
         "Cannot create an identifier with a nul character");

  auto &impl = context->getImpl();
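  // The StringMap copies the string into the context's bump allocator and
  // keeps it alive for the lifetime of the context; the returned Identifier
  // is just a pointer to that uniqued key data (the mapped 'char' value is
  // unused).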
  auto it = impl.identifiers.insert({str, char()}).first;
  return Identifier(it->getKeyData());
}

//===----------------------------------------------------------------------===//
// Types
//===----------------------------------------------------------------------===//

PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) {
  assert(kind <= TypeKind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind");
  auto &impl = context->getImpl();

  // We normally have these types.
  if (impl.primitives[(int)kind])
    return impl.primitives[(int)kind];

  // On the first use, we allocate them into the bump pointer.
  auto *ptr = impl.allocator.Allocate<PrimitiveType>();

  // Initialize the memory using placement new.
  new(ptr) PrimitiveType(kind, context);

  // Cache and return it.
  return impl.primitives[(int)kind] = ptr;
}

IntegerType *IntegerType::get(unsigned width, MLIRContext *context) {
  auto &impl = context->getImpl();

  auto *&result = impl.integers[width];
  if (!result) {
    result = impl.allocator.Allocate<IntegerType>();
    new (result) IntegerType(width, context);
  }

  return result;
}

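// Example (illustrative only; `ctx` stands for some MLIRContext *): the type
// of a function taking two i32 values and returning one i32 could be built
// and uniqued as
//
//   Type *i32 = IntegerType::get(32, ctx);
//   FunctionType *fnTy = FunctionType::get({i32, i32}, {i32}, ctx);
//
// A second call with the same inputs and results returns the same pointer, so
// uniqued types can be compared by address.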
FunctionType *FunctionType::get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
                                MLIRContext *context) {
  auto &impl = context->getImpl();

  // Look to see if we already have this function type.
  FunctionTypeKeyInfo::KeyTy key(inputs, results);
  auto existing = impl.functions.insert_as(nullptr, key);
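  // insert_as probes the set using only the key; if an equivalent type is
  // already present we get an iterator to it, otherwise a new slot holding
  // the nullptr placeholder is created and filled in below.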

  // If we already have it, return that value.
  if (!existing.second)
    return *existing.first;

  // On the first use, we allocate them into the bump pointer.
  auto *result = impl.allocator.Allocate<FunctionType>();

  // Copy the inputs and results into the bump pointer.
  SmallVector<Type*, 16> types;
  types.reserve(inputs.size()+results.size());
  types.append(inputs.begin(), inputs.end());
  types.append(results.begin(), results.end());
  auto typesList = impl.copyInto(ArrayRef<Type*>(types));

  // Initialize the memory using placement new.
  new (result) FunctionType(typesList.data(), inputs.size(), results.size(),
                            context);

  // Cache and return it.
  return *existing.first = result;
}

VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
  assert(!shape.empty() && "vector types must have at least one dimension");
  assert((isa<PrimitiveType>(elementType) || isa<IntegerType>(elementType)) &&
         "vector elements must be primitives or integers");

  auto *context = elementType->getContext();
  auto &impl = context->getImpl();

  // Look to see if we already have this vector type.
  VectorTypeKeyInfo::KeyTy key(elementType, shape);
  auto existing = impl.vectors.insert_as(nullptr, key);

  // If we already have it, return that value.
  if (!existing.second)
    return *existing.first;

  // On the first use, we allocate them into the bump pointer.
  auto *result = impl.allocator.Allocate<VectorType>();

  // Copy the shape into the bump pointer.
  shape = impl.copyInto(shape);

  // Initialize the memory using placement new.
  new (result) VectorType(shape, cast<PrimitiveType>(elementType), context);

  // Cache and return it.
  return *existing.first = result;
}


TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context)
    : Type(kind, context), elementType(elementType) {
  assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType) ||
          isa<IntegerType>(elementType)) &&
         "tensor elements must be primitives, vectors, or integers");
  assert(isa<TensorType>(this));
}

RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
                                        Type *elementType) {
  auto *context = elementType->getContext();
  auto &impl = context->getImpl();

  // Look to see if we already have this ranked tensor type.
  RankedTensorTypeKeyInfo::KeyTy key(elementType, shape);
  auto existing = impl.rankedTensors.insert_as(nullptr, key);

  // If we already have it, return that value.
  if (!existing.second)
    return *existing.first;

  // On the first use, we allocate them into the bump pointer.
  auto *result = impl.allocator.Allocate<RankedTensorType>();

  // Copy the shape into the bump pointer.
  shape = impl.copyInto(shape);

  // Initialize the memory using placement new.
  new (result) RankedTensorType(shape, elementType, context);

  // Cache and return it.
  return *existing.first = result;
}

UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
  auto *context = elementType->getContext();
  auto &impl = context->getImpl();

  // Look to see if we already have this unranked tensor type.
  auto existing = impl.unrankedTensors.insert({elementType, nullptr});

  // If we already have it, return that value.
  if (!existing.second)
    return existing.first->second;

  // On the first use, we allocate them into the bump pointer.
  auto *result = impl.allocator.Allocate<UnrankedTensorType>();

  // Initialize the memory using placement new.
  new (result) UnrankedTensorType(elementType, context);

  // Cache and return it.
  return existing.first->second = result;
}

// TODO(bondhugula,andydavis): unique affine maps based on dim list,
// symbol list and all affine expressions contained
AffineMap *AffineMap::get(unsigned dimCount,
                          unsigned symbolCount,
                          ArrayRef<AffineExpr *> exprs,
                          MLIRContext *context) {
  // TODO(bondhugula)
  return new AffineMap(dimCount, symbolCount, exprs);
}

AffineBinaryOpExpr *AffineBinaryOpExpr::get(AffineExpr::Kind kind,
                                            AffineExpr *lhsOperand,
                                            AffineExpr *rhsOperand,
                                            MLIRContext *context) {
  // TODO(bondhugula): allocate this through context
  // FIXME
  return new AffineBinaryOpExpr(kind, lhsOperand, rhsOperand);
}

AffineAddExpr *AffineAddExpr::get(AffineExpr *lhsOperand,
                                  AffineExpr *rhsOperand,
                                  MLIRContext *context) {
  // TODO(bondhugula): allocate this through context
  // FIXME
  return new AffineAddExpr(lhsOperand, rhsOperand);
}

// TODO(bondhugula): add functions for AffineMulExpr, mod, floordiv, ceildiv

AffineDimExpr *AffineDimExpr::get(unsigned position, MLIRContext *context) {
  // TODO(bondhugula): complete this
  // FIXME: this should be POD
  return new AffineDimExpr(position);
}

AffineSymbolExpr *AffineSymbolExpr::get(unsigned position,
                                        MLIRContext *context) {
  // TODO(bondhugula): complete this
  // FIXME: this should be POD
  return new AffineSymbolExpr(position);
}

AffineConstantExpr *AffineConstantExpr::get(int64_t constant,
                                            MLIRContext *context) {
  // TODO(bondhugula): complete this
  // FIXME: this should be POD
  return new AffineConstantExpr(constant);
375}