blob: a2befc3909c55a7e725fadbab5ba7efa0254c49e [file] [log] [blame]
Chris Lattnerf7e22732018-06-22 22:03:48 -07001//===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17
18#include "mlir/IR/MLIRContext.h"
19#include "mlir/IR/Types.h"
20#include "mlir/Support/LLVM.h"
21#include "llvm/ADT/DenseSet.h"
22#include "llvm/Support/Allocator.h"
23using namespace mlir;
24using namespace llvm;
25
26namespace {
27struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType*> {
28 // Functions are uniqued based on their inputs and results.
29 using KeyTy = std::pair<ArrayRef<Type*>, ArrayRef<Type*>>;
30 using DenseMapInfo<FunctionType*>::getHashValue;
31 using DenseMapInfo<FunctionType*>::isEqual;
32
33 static unsigned getHashValue(KeyTy key) {
34 return hash_combine(hash_combine_range(key.first.begin(), key.first.end()),
35 hash_combine_range(key.second.begin(),
36 key.second.end()));
37 }
38
39 static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
40 if (rhs == getEmptyKey() || rhs == getTombstoneKey())
41 return false;
42 return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
43 }
44};
45struct VectorTypeKeyInfo : DenseMapInfo<VectorType*> {
46 // Vectors are uniqued based on their element type and shape.
47 using KeyTy = std::pair<Type*, ArrayRef<unsigned>>;
48 using DenseMapInfo<VectorType*>::getHashValue;
49 using DenseMapInfo<VectorType*>::isEqual;
50
51 static unsigned getHashValue(KeyTy key) {
52 return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
53 hash_combine_range(key.second.begin(),
54 key.second.end()));
55 }
56
57 static bool isEqual(const KeyTy &lhs, const VectorType *rhs) {
58 if (rhs == getEmptyKey() || rhs == getTombstoneKey())
59 return false;
60 return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
61 }
62};
63} // end anonymous namespace.
64
65
66namespace mlir {
67/// This is the implementation of the MLIRContext class, using the pImpl idiom.
68/// This class is completely private to this file, so everything is public.
69class MLIRContextImpl {
70public:
71 /// We put immortal objects into this allocator.
72 llvm::BumpPtrAllocator allocator;
73
74 // Primitive type uniquing.
75 PrimitiveType *primitives[int(TypeKind::LAST_PRIMITIVE_TYPE)+1] = { nullptr };
76
77 /// Function type uniquing.
78 using FunctionTypeSet = DenseSet<FunctionType*, FunctionTypeKeyInfo>;
79 FunctionTypeSet functions;
80
81 /// Vector type uniquing.
82 using VectorTypeSet = DenseSet<VectorType*, VectorTypeKeyInfo>;
83 VectorTypeSet vectors;
84
85
86public:
87 /// Copy the specified array of elements into memory managed by our bump
88 /// pointer allocator. This assumes the elements are all PODs.
89 template<typename T>
90 ArrayRef<T> copyInto(ArrayRef<T> elements) {
91 auto result = allocator.Allocate<T>(elements.size());
92 std::uninitialized_copy(elements.begin(), elements.end(), result);
93 return ArrayRef<T>(result, elements.size());
94 }
95};
96} // end namespace mlir
97
98MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
99}
100
101MLIRContext::~MLIRContext() {
102}
103
104
105PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context)
106 : Type(kind, context) {
107
108}
109
110PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) {
111 assert(kind <= TypeKind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind");
112 auto &impl = context->getImpl();
113
114 // We normally have these types.
115 if (impl.primitives[(int)kind])
116 return impl.primitives[(int)kind];
117
118 // On the first use, we allocate them into the bump pointer.
119 auto *ptr = impl.allocator.Allocate<PrimitiveType>();
120
121 // Initialize the memory using placement new.
122 new(ptr) PrimitiveType(kind, context);
123
124 // Cache and return it.
125 return impl.primitives[(int)kind] = ptr;
126}
127
128FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
129 unsigned numResults, MLIRContext *context)
130 : Type(TypeKind::Function, context, numInputs),
131 numResults(numResults), inputsAndResults(inputsAndResults) {
132}
133
134FunctionType *FunctionType::get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
135 MLIRContext *context) {
136 auto &impl = context->getImpl();
137
138 // Look to see if we already have this function type.
139 FunctionTypeKeyInfo::KeyTy key(inputs, results);
140 auto existing = impl.functions.insert_as(nullptr, key);
141
142 // If we already have it, return that value.
143 if (!existing.second)
144 return *existing.first;
145
146 // On the first use, we allocate them into the bump pointer.
147 auto *result = impl.allocator.Allocate<FunctionType>();
148
149 // Copy the inputs and results into the bump pointer.
150 SmallVector<Type*, 16> types;
151 types.reserve(inputs.size()+results.size());
152 types.append(inputs.begin(), inputs.end());
153 types.append(results.begin(), results.end());
154 auto typesList = impl.copyInto(ArrayRef<Type*>(types));
155
156 // Initialize the memory using placement new.
157 new (result) FunctionType(typesList.data(), inputs.size(), results.size(),
158 context);
159
160 // Cache and return it.
161 return *existing.first = result;
162}
163
164
165
166VectorType::VectorType(ArrayRef<unsigned> shape, PrimitiveType *elementType,
167 MLIRContext *context)
168 : Type(TypeKind::Vector, context, shape.size()),
169 shapeElements(shape.data()), elementType(elementType) {
170}
171
172
173VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
174 assert(!shape.empty() && "vector types must have at least one dimension");
175 assert(isa<PrimitiveType>(elementType) &&
176 "vectors elements must be primitives");
177
178 auto *context = elementType->getContext();
179 auto &impl = context->getImpl();
180
181 // Look to see if we already have this vector type.
182 VectorTypeKeyInfo::KeyTy key(elementType, shape);
183 auto existing = impl.vectors.insert_as(nullptr, key);
184
185 // If we already have it, return that value.
186 if (!existing.second)
187 return *existing.first;
188
189 // On the first use, we allocate them into the bump pointer.
190 auto *result = impl.allocator.Allocate<VectorType>();
191
192 // Copy the shape into the bump pointer.
193 shape = impl.copyInto(shape);
194
195 // Initialize the memory using placement new.
196 new (result) VectorType(shape, cast<PrimitiveType>(elementType), context);
197
198 // Cache and return it.
199 return *existing.first = result;
200}