blob: 8a035b6e681f4d9a40ee0255de0543e3be652198 [file] [log] [blame]
//===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
17
18#include "mlir/IR/MLIRContext.h"
19#include "mlir/IR/Types.h"
20#include "mlir/Support/LLVM.h"
21#include "llvm/ADT/DenseSet.h"
22#include "llvm/Support/Allocator.h"
23using namespace mlir;
24using namespace llvm;
25
26namespace {
27struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType*> {
28 // Functions are uniqued based on their inputs and results.
29 using KeyTy = std::pair<ArrayRef<Type*>, ArrayRef<Type*>>;
30 using DenseMapInfo<FunctionType*>::getHashValue;
31 using DenseMapInfo<FunctionType*>::isEqual;
32
33 static unsigned getHashValue(KeyTy key) {
34 return hash_combine(hash_combine_range(key.first.begin(), key.first.end()),
35 hash_combine_range(key.second.begin(),
36 key.second.end()));
37 }
38
39 static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
40 if (rhs == getEmptyKey() || rhs == getTombstoneKey())
41 return false;
42 return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
43 }
44};
45struct VectorTypeKeyInfo : DenseMapInfo<VectorType*> {
46 // Vectors are uniqued based on their element type and shape.
47 using KeyTy = std::pair<Type*, ArrayRef<unsigned>>;
48 using DenseMapInfo<VectorType*>::getHashValue;
49 using DenseMapInfo<VectorType*>::isEqual;
50
51 static unsigned getHashValue(KeyTy key) {
52 return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
53 hash_combine_range(key.second.begin(),
54 key.second.end()));
55 }
56
57 static bool isEqual(const KeyTy &lhs, const VectorType *rhs) {
58 if (rhs == getEmptyKey() || rhs == getTombstoneKey())
59 return false;
60 return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
61 }
62};
MLIR Team355ec862018-06-23 18:09:09 -070063struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType*> {
64 // Ranked tensors are uniqued based on their element type and shape.
65 using KeyTy = std::pair<Type*, ArrayRef<int>>;
66 using DenseMapInfo<RankedTensorType*>::getHashValue;
67 using DenseMapInfo<RankedTensorType*>::isEqual;
68
69 static unsigned getHashValue(KeyTy key) {
70 return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
71 hash_combine_range(key.second.begin(),
72 key.second.end()));
73 }
74
75 static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
76 if (rhs == getEmptyKey() || rhs == getTombstoneKey())
77 return false;
78 return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
79 }
80};
Chris Lattnerf7e22732018-06-22 22:03:48 -070081} // end anonymous namespace.
82
83
84namespace mlir {
85/// This is the implementation of the MLIRContext class, using the pImpl idiom.
86/// This class is completely private to this file, so everything is public.
87class MLIRContextImpl {
88public:
89 /// We put immortal objects into this allocator.
90 llvm::BumpPtrAllocator allocator;
91
92 // Primitive type uniquing.
93 PrimitiveType *primitives[int(TypeKind::LAST_PRIMITIVE_TYPE)+1] = { nullptr };
94
95 /// Function type uniquing.
96 using FunctionTypeSet = DenseSet<FunctionType*, FunctionTypeKeyInfo>;
97 FunctionTypeSet functions;
98
99 /// Vector type uniquing.
100 using VectorTypeSet = DenseSet<VectorType*, VectorTypeKeyInfo>;
101 VectorTypeSet vectors;
102
MLIR Team355ec862018-06-23 18:09:09 -0700103 /// Ranked tensor type uniquing.
104 using RankedTensorTypeSet = DenseSet<RankedTensorType*,
105 RankedTensorTypeKeyInfo>;
106 RankedTensorTypeSet rankedTensors;
107
108 /// Unranked tensor type uniquing.
109 DenseMap<Type*, UnrankedTensorType*> unrankedTensors;
110
Chris Lattnerf7e22732018-06-22 22:03:48 -0700111
112public:
113 /// Copy the specified array of elements into memory managed by our bump
114 /// pointer allocator. This assumes the elements are all PODs.
115 template<typename T>
116 ArrayRef<T> copyInto(ArrayRef<T> elements) {
117 auto result = allocator.Allocate<T>(elements.size());
118 std::uninitialized_copy(elements.begin(), elements.end(), result);
119 return ArrayRef<T>(result, elements.size());
120 }
121};
122} // end namespace mlir
123
124MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
125}
126
127MLIRContext::~MLIRContext() {
128}
129
130
131PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context)
132 : Type(kind, context) {
133
134}
135
136PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) {
137 assert(kind <= TypeKind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind");
138 auto &impl = context->getImpl();
139
140 // We normally have these types.
141 if (impl.primitives[(int)kind])
142 return impl.primitives[(int)kind];
143
144 // On the first use, we allocate them into the bump pointer.
145 auto *ptr = impl.allocator.Allocate<PrimitiveType>();
146
147 // Initialize the memory using placement new.
148 new(ptr) PrimitiveType(kind, context);
149
150 // Cache and return it.
151 return impl.primitives[(int)kind] = ptr;
152}
153
154FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
155 unsigned numResults, MLIRContext *context)
156 : Type(TypeKind::Function, context, numInputs),
157 numResults(numResults), inputsAndResults(inputsAndResults) {
158}
159
160FunctionType *FunctionType::get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
161 MLIRContext *context) {
162 auto &impl = context->getImpl();
163
164 // Look to see if we already have this function type.
165 FunctionTypeKeyInfo::KeyTy key(inputs, results);
166 auto existing = impl.functions.insert_as(nullptr, key);
167
168 // If we already have it, return that value.
169 if (!existing.second)
170 return *existing.first;
171
172 // On the first use, we allocate them into the bump pointer.
173 auto *result = impl.allocator.Allocate<FunctionType>();
174
175 // Copy the inputs and results into the bump pointer.
176 SmallVector<Type*, 16> types;
177 types.reserve(inputs.size()+results.size());
178 types.append(inputs.begin(), inputs.end());
179 types.append(results.begin(), results.end());
180 auto typesList = impl.copyInto(ArrayRef<Type*>(types));
181
182 // Initialize the memory using placement new.
183 new (result) FunctionType(typesList.data(), inputs.size(), results.size(),
184 context);
185
186 // Cache and return it.
187 return *existing.first = result;
188}
189
190
191
192VectorType::VectorType(ArrayRef<unsigned> shape, PrimitiveType *elementType,
193 MLIRContext *context)
194 : Type(TypeKind::Vector, context, shape.size()),
195 shapeElements(shape.data()), elementType(elementType) {
196}
197
198
199VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
200 assert(!shape.empty() && "vector types must have at least one dimension");
201 assert(isa<PrimitiveType>(elementType) &&
202 "vectors elements must be primitives");
203
204 auto *context = elementType->getContext();
205 auto &impl = context->getImpl();
206
207 // Look to see if we already have this vector type.
208 VectorTypeKeyInfo::KeyTy key(elementType, shape);
209 auto existing = impl.vectors.insert_as(nullptr, key);
210
211 // If we already have it, return that value.
212 if (!existing.second)
213 return *existing.first;
214
215 // On the first use, we allocate them into the bump pointer.
216 auto *result = impl.allocator.Allocate<VectorType>();
217
218 // Copy the shape into the bump pointer.
219 shape = impl.copyInto(shape);
220
221 // Initialize the memory using placement new.
222 new (result) VectorType(shape, cast<PrimitiveType>(elementType), context);
223
224 // Cache and return it.
225 return *existing.first = result;
226}
MLIR Team355ec862018-06-23 18:09:09 -0700227
228
229TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context)
230 : Type(kind, context), elementType(elementType) {
231 assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType)) &&
232 "tensor elements must be primitives or vectors");
233 assert(isa<TensorType>(this));
234}
235
236RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
237 MLIRContext *context)
238 : TensorType(TypeKind::RankedTensor, elementType, context),
239 shapeElements(shape.data()) {
240 setSubclassData(shape.size());
241}
242
243UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
244 : TensorType(TypeKind::UnrankedTensor, elementType, context) {
245}
246
247RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
248 Type *elementType) {
249 auto *context = elementType->getContext();
250 auto &impl = context->getImpl();
251
252 // Look to see if we already have this ranked tensor type.
253 RankedTensorTypeKeyInfo::KeyTy key(elementType, shape);
254 auto existing = impl.rankedTensors.insert_as(nullptr, key);
255
256 // If we already have it, return that value.
257 if (!existing.second)
258 return *existing.first;
259
260 // On the first use, we allocate them into the bump pointer.
261 auto *result = impl.allocator.Allocate<RankedTensorType>();
262
263 // Copy the shape into the bump pointer.
264 shape = impl.copyInto(shape);
265
266 // Initialize the memory using placement new.
267 new (result) RankedTensorType(shape, elementType, context);
268
269 // Cache and return it.
270 return *existing.first = result;
271}
272
273UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
274 auto *context = elementType->getContext();
275 auto &impl = context->getImpl();
276
277 // Look to see if we already have this unranked tensor type.
278 auto existing = impl.unrankedTensors.insert({elementType, nullptr});
279
280 // If we already have it, return that value.
281 if (!existing.second)
282 return existing.first->second;
283
284 // On the first use, we allocate them into the bump pointer.
285 auto *result = impl.allocator.Allocate<UnrankedTensorType>();
286
287 // Initialize the memory using placement new.
288 new (result) UnrankedTensorType(elementType, context);
289
290 // Cache and return it.
291 return existing.first->second = result;
292}