blob: 9c30660d951dbb1c30960c88d43a95b4bb0bc143 [file] [log] [blame]
//===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/MLIRContext.h"
19#include "mlir/IR/Types.h"
20#include "mlir/Support/LLVM.h"
21#include "llvm/ADT/DenseSet.h"
22#include "llvm/Support/Allocator.h"
23using namespace mlir;
24using namespace llvm;
25
26namespace {
27struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType*> {
28 // Functions are uniqued based on their inputs and results.
29 using KeyTy = std::pair<ArrayRef<Type*>, ArrayRef<Type*>>;
30 using DenseMapInfo<FunctionType*>::getHashValue;
31 using DenseMapInfo<FunctionType*>::isEqual;
32
33 static unsigned getHashValue(KeyTy key) {
34 return hash_combine(hash_combine_range(key.first.begin(), key.first.end()),
35 hash_combine_range(key.second.begin(),
36 key.second.end()));
37 }
38
39 static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
40 if (rhs == getEmptyKey() || rhs == getTombstoneKey())
41 return false;
42 return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
43 }
44};
45struct VectorTypeKeyInfo : DenseMapInfo<VectorType*> {
46 // Vectors are uniqued based on their element type and shape.
47 using KeyTy = std::pair<Type*, ArrayRef<unsigned>>;
48 using DenseMapInfo<VectorType*>::getHashValue;
49 using DenseMapInfo<VectorType*>::isEqual;
50
51 static unsigned getHashValue(KeyTy key) {
52 return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
53 hash_combine_range(key.second.begin(),
54 key.second.end()));
55 }
56
57 static bool isEqual(const KeyTy &lhs, const VectorType *rhs) {
58 if (rhs == getEmptyKey() || rhs == getTombstoneKey())
59 return false;
60 return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
61 }
62};
MLIR Team355ec862018-06-23 18:09:09 -070063struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType*> {
64 // Ranked tensors are uniqued based on their element type and shape.
65 using KeyTy = std::pair<Type*, ArrayRef<int>>;
66 using DenseMapInfo<RankedTensorType*>::getHashValue;
67 using DenseMapInfo<RankedTensorType*>::isEqual;
68
69 static unsigned getHashValue(KeyTy key) {
70 return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
71 hash_combine_range(key.second.begin(),
72 key.second.end()));
73 }
74
75 static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
76 if (rhs == getEmptyKey() || rhs == getTombstoneKey())
77 return false;
78 return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
79 }
80};
81struct UnrankedTensorTypeKeyInfo : DenseMapInfo<UnrankedTensorType*> {
82 // Ranked tensors are uniqued based on their element type and shape.
83 using KeyTy = Type*;
84 using DenseMapInfo<UnrankedTensorType*>::getHashValue;
85 using DenseMapInfo<UnrankedTensorType*>::isEqual;
86
87 static unsigned getHashValue(KeyTy key) {
88 return hash_combine(DenseMapInfo<Type*>::getHashValue(key));
89 }
90
91 static bool isEqual(const KeyTy &lhs, const UnrankedTensorType *rhs) {
92 if (rhs == getEmptyKey() || rhs == getTombstoneKey())
93 return false;
94 return lhs == rhs->getElementType();
95 }
96};
} // end anonymous namespace.
98
99
100namespace mlir {
/// This is the implementation of the MLIRContext class, using the pImpl idiom.
/// This class is completely private to this file, so everything is public.
class MLIRContextImpl {
public:
  /// We put immortal objects into this allocator.  Uniqued types live for the
  /// lifetime of the context and are never individually freed.
  llvm::BumpPtrAllocator allocator;

  // Primitive type uniquing: one lazily-created instance per primitive kind,
  // indexed by the kind's integer value (populated by PrimitiveType::get).
  PrimitiveType *primitives[int(TypeKind::LAST_PRIMITIVE_TYPE)+1] = { nullptr };

  /// Function type uniquing, keyed on (inputs, results).
  using FunctionTypeSet = DenseSet<FunctionType*, FunctionTypeKeyInfo>;
  FunctionTypeSet functions;

  /// Vector type uniquing, keyed on (element type, shape).
  using VectorTypeSet = DenseSet<VectorType*, VectorTypeKeyInfo>;
  VectorTypeSet vectors;

  /// Ranked tensor type uniquing, keyed on (element type, shape).
  using RankedTensorTypeSet = DenseSet<RankedTensorType*,
                                       RankedTensorTypeKeyInfo>;
  RankedTensorTypeSet rankedTensors;

  /// Unranked tensor type uniquing, keyed on the element type alone.
  DenseMap<Type*, UnrankedTensorType*> unrankedTensors;


public:
  /// Copy the specified array of elements into memory managed by our bump
  /// pointer allocator. This assumes the elements are all PODs.
  template<typename T>
  ArrayRef<T> copyInto(ArrayRef<T> elements) {
    auto result = allocator.Allocate<T>(elements.size());
    std::uninitialized_copy(elements.begin(), elements.end(), result);
    return ArrayRef<T>(result, elements.size());
  }
};
138} // end namespace mlir
139
/// Create a new context: allocates the file-private pImpl state.
MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
}
142
/// Out-of-line destructor: must live here, where MLIRContextImpl is a complete
/// type, so the pImpl pointer can be deleted.
MLIRContext::~MLIRContext() {
}
145
146
147PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context)
148 : Type(kind, context) {
149
150}
151
152PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) {
153 assert(kind <= TypeKind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind");
154 auto &impl = context->getImpl();
155
156 // We normally have these types.
157 if (impl.primitives[(int)kind])
158 return impl.primitives[(int)kind];
159
160 // On the first use, we allocate them into the bump pointer.
161 auto *ptr = impl.allocator.Allocate<PrimitiveType>();
162
163 // Initialize the memory using placement new.
164 new(ptr) PrimitiveType(kind, context);
165
166 // Cache and return it.
167 return impl.primitives[(int)kind] = ptr;
168}
169
/// Construct a function type. `inputsAndResults` points at a single
/// context-owned array holding the input types followed by the result types;
/// only numResults is stored here, while numInputs is forwarded to Type
/// (presumably kept in Type's subclass-data field — see FunctionType::get).
FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
                           unsigned numResults, MLIRContext *context)
  : Type(TypeKind::Function, context, numInputs),
    numResults(numResults), inputsAndResults(inputsAndResults) {
}
175
176FunctionType *FunctionType::get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
177 MLIRContext *context) {
178 auto &impl = context->getImpl();
179
180 // Look to see if we already have this function type.
181 FunctionTypeKeyInfo::KeyTy key(inputs, results);
182 auto existing = impl.functions.insert_as(nullptr, key);
183
184 // If we already have it, return that value.
185 if (!existing.second)
186 return *existing.first;
187
188 // On the first use, we allocate them into the bump pointer.
189 auto *result = impl.allocator.Allocate<FunctionType>();
190
191 // Copy the inputs and results into the bump pointer.
192 SmallVector<Type*, 16> types;
193 types.reserve(inputs.size()+results.size());
194 types.append(inputs.begin(), inputs.end());
195 types.append(results.begin(), results.end());
196 auto typesList = impl.copyInto(ArrayRef<Type*>(types));
197
198 // Initialize the memory using placement new.
199 new (result) FunctionType(typesList.data(), inputs.size(), results.size(),
200 context);
201
202 // Cache and return it.
203 return *existing.first = result;
204}
205
206
207
/// Construct a vector type. `shape` must point at context-owned memory (the
/// caller copies it into the bump allocator first); shape.size() is forwarded
/// to Type — presumably stored as subclass data so getShape() can recover the
/// rank from just the shapeElements pointer.
VectorType::VectorType(ArrayRef<unsigned> shape, PrimitiveType *elementType,
                       MLIRContext *context)
  : Type(TypeKind::Vector, context, shape.size()),
    shapeElements(shape.data()), elementType(elementType) {
}
213
214
215VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
216 assert(!shape.empty() && "vector types must have at least one dimension");
217 assert(isa<PrimitiveType>(elementType) &&
218 "vectors elements must be primitives");
219
220 auto *context = elementType->getContext();
221 auto &impl = context->getImpl();
222
223 // Look to see if we already have this vector type.
224 VectorTypeKeyInfo::KeyTy key(elementType, shape);
225 auto existing = impl.vectors.insert_as(nullptr, key);
226
227 // If we already have it, return that value.
228 if (!existing.second)
229 return *existing.first;
230
231 // On the first use, we allocate them into the bump pointer.
232 auto *result = impl.allocator.Allocate<VectorType>();
233
234 // Copy the shape into the bump pointer.
235 shape = impl.copyInto(shape);
236
237 // Initialize the memory using placement new.
238 new (result) VectorType(shape, cast<PrimitiveType>(elementType), context);
239
240 // Cache and return it.
241 return *existing.first = result;
242}
244
/// Construct the common tensor base state. Tensor element types are restricted
/// to primitives and vectors; the invariant is enforced here once rather than
/// in each subclass constructor.
TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context)
  : Type(kind, context), elementType(elementType) {
  assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType)) &&
         "tensor elements must be primitives or vectors");
  assert(isa<TensorType>(this)); // `kind` must denote a tensor subclass.
}
251
/// Construct a ranked tensor type. `shape` must point at context-owned memory;
/// the rank (shape.size()) is recorded via setSubclassData so only the data
/// pointer needs a dedicated member.
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
                                   MLIRContext *context)
  : TensorType(TypeKind::RankedTensor, elementType, context),
    shapeElements(shape.data()) {
  setSubclassData(shape.size());
}
258
/// Construct an unranked tensor type: an element type with no shape at all.
UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
  : TensorType(TypeKind::UnrankedTensor, elementType, context) {
}
262
263RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
264 Type *elementType) {
265 auto *context = elementType->getContext();
266 auto &impl = context->getImpl();
267
268 // Look to see if we already have this ranked tensor type.
269 RankedTensorTypeKeyInfo::KeyTy key(elementType, shape);
270 auto existing = impl.rankedTensors.insert_as(nullptr, key);
271
272 // If we already have it, return that value.
273 if (!existing.second)
274 return *existing.first;
275
276 // On the first use, we allocate them into the bump pointer.
277 auto *result = impl.allocator.Allocate<RankedTensorType>();
278
279 // Copy the shape into the bump pointer.
280 shape = impl.copyInto(shape);
281
282 // Initialize the memory using placement new.
283 new (result) RankedTensorType(shape, elementType, context);
284
285 // Cache and return it.
286 return *existing.first = result;
287}
288
289UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
290 auto *context = elementType->getContext();
291 auto &impl = context->getImpl();
292
293 // Look to see if we already have this unranked tensor type.
294 auto existing = impl.unrankedTensors.insert({elementType, nullptr});
295
296 // If we already have it, return that value.
297 if (!existing.second)
298 return existing.first->second;
299
300 // On the first use, we allocate them into the bump pointer.
301 auto *result = impl.allocator.Allocate<UnrankedTensorType>();
302
303 // Initialize the memory using placement new.
304 new (result) UnrankedTensorType(elementType, context);
305
306 // Cache and return it.
307 return existing.first->second = result;
308}