//===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Allocator.h"
using namespace mlir;
using namespace llvm;

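// The KeyInfo traits below allow uniquing composite types in a DenseSet keyed
// on the type's defining data instead of its pointer identity. Each trait
// hashes the key and compares a key against a stored pointer, taking care to
// reject DenseMap's empty/tombstone sentinel pointers. For example,
// FunctionType::get below uses FunctionTypeKeyInfo roughly like:
//
//   FunctionTypeKeyInfo::KeyTy key(inputs, results);
//   auto existing = impl.functions.insert_as(nullptr, key);
//   if (!existing.second)
//     return *existing.first;  // Already uniqued, reuse the existing object.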
namespace {
struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType*> {
  // Functions are uniqued based on their inputs and results.
  using KeyTy = std::pair<ArrayRef<Type*>, ArrayRef<Type*>>;
  using DenseMapInfo<FunctionType*>::getHashValue;
  using DenseMapInfo<FunctionType*>::isEqual;

  static unsigned getHashValue(KeyTy key) {
    return hash_combine(hash_combine_range(key.first.begin(), key.first.end()),
                        hash_combine_range(key.second.begin(),
                                           key.second.end()));
  }

  static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
      return false;
    return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
  }
};

struct VectorTypeKeyInfo : DenseMapInfo<VectorType*> {
  // Vectors are uniqued based on their element type and shape.
  using KeyTy = std::pair<Type*, ArrayRef<unsigned>>;
  using DenseMapInfo<VectorType*>::getHashValue;
  using DenseMapInfo<VectorType*>::isEqual;

  static unsigned getHashValue(KeyTy key) {
    return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
                        hash_combine_range(key.second.begin(),
                                           key.second.end()));
  }

  static bool isEqual(const KeyTy &lhs, const VectorType *rhs) {
    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
      return false;
    return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
  }
};

struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType*> {
  // Ranked tensors are uniqued based on their element type and shape.
  using KeyTy = std::pair<Type*, ArrayRef<int>>;
  using DenseMapInfo<RankedTensorType*>::getHashValue;
  using DenseMapInfo<RankedTensorType*>::isEqual;

  static unsigned getHashValue(KeyTy key) {
    return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
                        hash_combine_range(key.second.begin(),
                                           key.second.end()));
  }

  static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
      return false;
    return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
  }
};
} // end anonymous namespace.

namespace mlir {
/// This is the implementation of the MLIRContext class, using the pImpl idiom.
/// This class is completely private to this file, so everything is public.
class MLIRContextImpl {
public:
  /// We put immortal objects into this allocator.
  llvm::BumpPtrAllocator allocator;

  /// These are identifiers uniqued into this MLIRContext.
  llvm::StringMap<char, llvm::BumpPtrAllocator&> identifiers;

  // Primitive type uniquing.
  PrimitiveType *primitives[int(TypeKind::LAST_PRIMITIVE_TYPE)+1] = { nullptr };

  /// Function type uniquing.
  using FunctionTypeSet = DenseSet<FunctionType*, FunctionTypeKeyInfo>;
  FunctionTypeSet functions;

  /// Vector type uniquing.
  using VectorTypeSet = DenseSet<VectorType*, VectorTypeKeyInfo>;
  VectorTypeSet vectors;

  /// Ranked tensor type uniquing.
  using RankedTensorTypeSet = DenseSet<RankedTensorType*,
                                       RankedTensorTypeKeyInfo>;
  RankedTensorTypeSet rankedTensors;

  /// Unranked tensor type uniquing.
  DenseMap<Type*, UnrankedTensorType*> unrankedTensors;

public:
  MLIRContextImpl() : identifiers(allocator) {}

  /// Copy the specified array of elements into memory managed by our bump
  /// pointer allocator. This assumes the elements are all PODs.
  template<typename T>
  ArrayRef<T> copyInto(ArrayRef<T> elements) {
    auto result = allocator.Allocate<T>(elements.size());
    std::uninitialized_copy(elements.begin(), elements.end(), result);
    return ArrayRef<T>(result, elements.size());
  }
};
} // end namespace mlir

MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
}

MLIRContext::~MLIRContext() {
}

//===----------------------------------------------------------------------===//
// Identifier
//===----------------------------------------------------------------------===//

/// Return an identifier for the specified string.
Identifier Identifier::get(StringRef str, const MLIRContext *context) {
  assert(!str.empty() && "Cannot create an empty identifier");
  assert(str.find('\0') == StringRef::npos &&
         "Cannot create an identifier with a nul character");

  auto &impl = context->getImpl();
  auto it = impl.identifiers.insert({str, char()}).first;
  return Identifier(it->getKeyData());
}
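
// Identifiers are uniqued in the context's StringMap, so looking up the same
// string twice yields pointer-identical data. A minimal usage sketch, assuming
// a live MLIRContext:
//
//   auto a = Identifier::get("foo", context);
//   auto b = Identifier::get("foo", context);
//   // a and b refer to the same uniqued character data.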

//===----------------------------------------------------------------------===//
// Types
//===----------------------------------------------------------------------===//

PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context)
    : Type(kind, context) {
}

PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) {
  assert(kind <= TypeKind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind");
  auto &impl = context->getImpl();

  // We normally have these types.
  if (impl.primitives[(int)kind])
    return impl.primitives[(int)kind];

  // On the first use, we allocate them into the bump pointer.
  auto *ptr = impl.allocator.Allocate<PrimitiveType>();

  // Initialize the memory using placement new.
  new(ptr) PrimitiveType(kind, context);

  // Cache and return it.
  return impl.primitives[(int)kind] = ptr;
}
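
// Primitive types are cached in a per-context array indexed by kind, so
// repeated requests return the same object. Illustrative sketch (the concrete
// TypeKind enumerators live in Types.h and are only assumed here):
//
//   auto *t1 = PrimitiveType::get(someKind, context);
//   auto *t2 = PrimitiveType::get(someKind, context);
//   assert(t1 == t2 && "primitive types are uniqued per context");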

FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
                           unsigned numResults, MLIRContext *context)
    : Type(TypeKind::Function, context, numInputs),
      numResults(numResults), inputsAndResults(inputsAndResults) {
}

FunctionType *FunctionType::get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
                                MLIRContext *context) {
  auto &impl = context->getImpl();

  // Look to see if we already have this function type.
  FunctionTypeKeyInfo::KeyTy key(inputs, results);
  auto existing = impl.functions.insert_as(nullptr, key);

  // If we already have it, return that value.
  if (!existing.second)
    return *existing.first;

  // On the first use, we allocate them into the bump pointer.
  auto *result = impl.allocator.Allocate<FunctionType>();

  // Copy the inputs and results into the bump pointer.
  SmallVector<Type*, 16> types;
  types.reserve(inputs.size()+results.size());
  types.append(inputs.begin(), inputs.end());
  types.append(results.begin(), results.end());
  auto typesList = impl.copyInto(ArrayRef<Type*>(types));

  // Initialize the memory using placement new.
  new (result) FunctionType(typesList.data(), inputs.size(), results.size(),
                            context);

  // Cache and return it.
  return *existing.first = result;
}
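
// The inputs and results are stored in a single context-owned array: the first
// numInputs entries are the inputs and the remaining numResults entries are
// the results. Illustrative sketch, where i32Ty and f32Ty stand for previously
// constructed primitive types:
//
//   auto *fn = FunctionType::get({i32Ty, f32Ty}, {f32Ty}, context);
//   // Requesting the same signature again returns the same FunctionType*.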

VectorType::VectorType(ArrayRef<unsigned> shape, PrimitiveType *elementType,
                       MLIRContext *context)
    : Type(TypeKind::Vector, context, shape.size()),
      shapeElements(shape.data()), elementType(elementType) {
}

VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
  assert(!shape.empty() && "vector types must have at least one dimension");
  assert(isa<PrimitiveType>(elementType) &&
         "vector elements must be primitives");

  auto *context = elementType->getContext();
  auto &impl = context->getImpl();

  // Look to see if we already have this vector type.
  VectorTypeKeyInfo::KeyTy key(elementType, shape);
  auto existing = impl.vectors.insert_as(nullptr, key);

  // If we already have it, return that value.
  if (!existing.second)
    return *existing.first;

  // On the first use, we allocate them into the bump pointer.
  auto *result = impl.allocator.Allocate<VectorType>();

  // Copy the shape into the bump pointer.
  shape = impl.copyInto(shape);

  // Initialize the memory using placement new.
  new (result) VectorType(shape, cast<PrimitiveType>(elementType), context);

  // Cache and return it.
  return *existing.first = result;
}
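
// The shape is copied into the context's allocator, so callers may pass a
// temporary. Identical (element type, shape) pairs map to the same object.
// Illustrative sketch, with f32Ty standing for an existing primitive type:
//
//   auto *v1 = VectorType::get({4, 8}, f32Ty);
//   auto *v2 = VectorType::get({4, 8}, f32Ty);
//   assert(v1 == v2 && "vector types are uniqued per context");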

TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context)
    : Type(kind, context), elementType(elementType) {
  assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType)) &&
         "tensor elements must be primitives or vectors");
  assert(isa<TensorType>(this));
}

RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
                                   MLIRContext *context)
    : TensorType(TypeKind::RankedTensor, elementType, context),
      shapeElements(shape.data()) {
  setSubclassData(shape.size());
}

UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
    : TensorType(TypeKind::UnrankedTensor, elementType, context) {
}

RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
                                        Type *elementType) {
  auto *context = elementType->getContext();
  auto &impl = context->getImpl();

  // Look to see if we already have this ranked tensor type.
  RankedTensorTypeKeyInfo::KeyTy key(elementType, shape);
  auto existing = impl.rankedTensors.insert_as(nullptr, key);

  // If we already have it, return that value.
  if (!existing.second)
    return *existing.first;

  // On the first use, we allocate them into the bump pointer.
  auto *result = impl.allocator.Allocate<RankedTensorType>();

  // Copy the shape into the bump pointer.
  shape = impl.copyInto(shape);

  // Initialize the memory using placement new.
  new (result) RankedTensorType(shape, elementType, context);

  // Cache and return it.
  return *existing.first = result;
}
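
// Ranked tensor types follow the same unique-and-copy pattern, keyed on the
// (element type, shape) pair. Illustrative sketch, with elemTy standing for an
// existing primitive or vector type (as the TensorType constructor requires):
//
//   auto *t1 = RankedTensorType::get({2, 3}, elemTy);
//   auto *t2 = RankedTensorType::get({2, 3}, elemTy);
//   assert(t1 == t2 && "ranked tensor types are uniqued per context");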

UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
  auto *context = elementType->getContext();
  auto &impl = context->getImpl();

  // Look to see if we already have this unranked tensor type.
  auto existing = impl.unrankedTensors.insert({elementType, nullptr});

  // If we already have it, return that value.
  if (!existing.second)
    return existing.first->second;

  // On the first use, we allocate them into the bump pointer.
  auto *result = impl.allocator.Allocate<UnrankedTensorType>();

  // Initialize the memory using placement new.
  new (result) UnrankedTensorType(elementType, context);

  // Cache and return it.
  return existing.first->second = result;
}
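
// Unranked tensors carry no shape, so a plain DenseMap from element type to
// the uniqued object suffices: identical element types always yield the same
// UnrankedTensorType instance, e.g. (with elemTy as above)
//
//   assert(UnrankedTensorType::get(elemTy) == UnrankedTensorType::get(elemTy));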