Implement IR support for attributes.
PiperOrigin-RevId: 203293376
diff --git a/lib/IR/Attributes.cpp b/lib/IR/Attributes.cpp
new file mode 100644
index 0000000..0a773f1
--- /dev/null
+++ b/lib/IR/Attributes.cpp
@@ -0,0 +1,57 @@
+//===- Attributes.cpp - MLIR Attribute Implementation ---------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Attributes.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Support/STLExtras.h"
+using namespace mlir;
+
+/// Print this attribute in its textual form to `os`.
+///
+/// Booleans print as `true`/`false`, integers and floats through the
+/// raw_ostream numeric operators, strings as a double-quoted literal with
+/// special characters escaped, and arrays as `[ elt, elt ]` (or `[]` when
+/// empty), recursively printing each element.
+void Attribute::print(raw_ostream &os) const {
+  switch (getKind()) {
+  case Kind::Bool:
+    os << (cast<BoolAttr>(this)->getValue() ? "true" : "false");
+    break;
+  case Kind::Integer:
+    os << cast<IntegerAttr>(this)->getValue();
+    break;
+  case Kind::Float:
+    // FIXME: this isn't precise, we should print with a hex format.
+    os << cast<FloatAttr>(this)->getValue();
+    break;
+  case Kind::String:
+    // Escape the payload so an embedded '"', '\\' or control character cannot
+    // terminate or corrupt the literal. write_escaped emits \\, \t, \n, \"
+    // and octal escapes for other non-printable bytes.
+    // NOTE(review): the attribute parser must accept these same escape
+    // sequences for output to round-trip — confirm.
+    os << '"';
+    os.write_escaped(cast<StringAttr>(this)->getValue());
+    os << '"';
+    break;
+  case Kind::Array: {
+    auto elts = cast<ArrayAttr>(this)->getValue();
+    if (elts.empty()) {
+      os << "[]";
+      break;
+    }
+    os << "[ ";
+    interleave(
+        elts, [&](Attribute *attr) { attr->print(os); },
+        [&]() { os << ", "; });
+    os << " ]";
+    break;
+  }
+  }
+}
+
+/// Debugging helper: print this attribute to llvm::errs(). No trailing
+/// newline is written.
+void Attribute::dump() const {
+  auto &errStream = llvm::errs();
+  print(errStream);
+}
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index d148523..3524909 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -16,11 +16,12 @@
// =============================================================================
#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Identifier.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Identifier.h"
#include "mlir/IR/Types.h"
-#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Allocator.h"
+#include "llvm/Support/MathExtras.h"
@@ -50,23 +51,21 @@
struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
// Affine maps are uniqued based on their dim/symbol counts and affine
// expressions.
- using KeyTy =
- std::pair<std::pair<unsigned, unsigned>, ArrayRef<AffineExpr *>>;
+ using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr *>>;
-using DenseMapInfo<AffineMap *>::getHashValue;
using DenseMapInfo<AffineMap *>::isEqual;
+
+  // A stored map must hash exactly like its lookup key. DenseMap rehashes
+  // the stored pointers whenever it grows; inheriting the pointer hash would
+  // move entries into buckets that key-based insert_as() lookups never
+  // probe, silently creating duplicate AffineMaps after the first growth.
+  static unsigned getHashValue(const AffineMap *map) {
+    return getHashValue(
+        KeyTy(map->getNumDims(), map->getNumSymbols(), map->getResults()));
+  }
+
+  // Hash a lookup key from its dim count, symbol count and result exprs.
static unsigned getHashValue(KeyTy key) {
return hash_combine(
- key.first.first, key.first.second,
- hash_combine_range(key.second.begin(), key.second.end()));
+ std::get<0>(key), std::get<1>(key),
+ hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()));
}
static bool isEqual(const KeyTy &lhs, const AffineMap *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
- return lhs == KeyTy(std::pair<unsigned, unsigned>(rhs->getNumDims(),
- rhs->getNumSymbols()),
- rhs->getResults());
+ return lhs == std::make_tuple(rhs->getNumDims(), rhs->getNumSymbols(),
+ rhs->getResults());
}
};
@@ -88,6 +87,7 @@
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
}
};
+
struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType*> {
// Ranked tensors are uniqued based on their element type and shape.
using KeyTy = std::pair<Type*, ArrayRef<int>>;
@@ -106,6 +106,23 @@
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
}
};
+
+struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttr*> {
+  // Array attributes are uniqued based on their elements.
+  using KeyTy = ArrayRef<Attribute*>;
+  using DenseMapInfo<ArrayAttr*>::isEqual;
+
+  // A stored attribute must hash exactly like its element-list key. DenseMap
+  // rehashes the stored pointers whenever it grows; inheriting the pointer
+  // hash would move entries into buckets that insert_as() lookups (which
+  // hash the elements) never probe, silently producing duplicate ArrayAttrs.
+  static unsigned getHashValue(const ArrayAttr *attr) {
+    return getHashValue(KeyTy(attr->getValue()));
+  }
+
+  // Hash a candidate element list.
+  static unsigned getHashValue(KeyTy key) {
+    return hash_combine_range(key.begin(), key.end());
+  }
+
+  // A lookup key matches a stored attribute with the same element list; the
+  // empty and tombstone sentinels never match.
+  static bool isEqual(const KeyTy &lhs, const ArrayAttr *rhs) {
+    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+      return false;
+    return lhs == rhs->getValue();
+  }
+};
} // end anonymous namespace.
@@ -129,8 +146,7 @@
// Affine binary op expression uniquing. Figure out uniquing of dimensional
// or symbolic identifiers.
- // std::tuple doesn't work with DenseMap!, using nested pair.
- DenseMap<std::pair<unsigned, std::pair<AffineExpr *, AffineExpr *>>,
+ DenseMap<std::tuple<unsigned, AffineExpr *, AffineExpr *>,
AffineBinaryOpExpr *>
affineExprs;
@@ -153,6 +169,13 @@
/// Unranked tensor type uniquing.
DenseMap<Type*, UnrankedTensorType*> unrankedTensors;
+ // Attribute uniquing.
+ BoolAttr *boolAttrs[2] = { nullptr };
+ DenseMap<int64_t, IntegerAttr*> integerAttrs;
+ DenseMap<int64_t, FloatAttr*> floatAttrs;
+ StringMap<StringAttr*> stringAttrs;
+ using ArrayAttrSet = DenseSet<ArrayAttr*, ArrayAttrKeyInfo>;
+ ArrayAttrSet arrayAttrs;
public:
MLIRContextImpl() : identifiers(allocator) {}
@@ -176,7 +199,7 @@
//===----------------------------------------------------------------------===//
-// Identifier
+// Identifier uniquing
//===----------------------------------------------------------------------===//
/// Return an identifier for the specified string.
@@ -191,7 +214,7 @@
}
//===----------------------------------------------------------------------===//
-// Types
+// Type uniquing
//===----------------------------------------------------------------------===//
PrimitiveType *PrimitiveType::get(Kind kind, MLIRContext *context) {
@@ -323,22 +346,101 @@
auto &impl = context->getImpl();
// Look to see if we already have this unranked tensor type.
- auto existing = impl.unrankedTensors.insert({elementType, nullptr});
+ auto *&result = impl.unrankedTensors[elementType];
// If we already have it, return that value.
- if (!existing.second)
- return existing.first->second;
+ if (result)
+ return result;
// On the first use, we allocate them into the bump pointer.
- auto *result = impl.allocator.Allocate<UnrankedTensorType>();
+ result = impl.allocator.Allocate<UnrankedTensorType>();
// Initialize the memory using placement new.
new (result) UnrankedTensorType(elementType, context);
+ return result;
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute uniquing
+//===----------------------------------------------------------------------===//
+
+/// Return the uniqued BoolAttr for `value`. The context keeps one slot per
+/// boolean value; the attribute is created lazily on the first request.
+BoolAttr *BoolAttr::get(bool value, MLIRContext *context) {
+  auto &impl = context->getImpl();
+  BoolAttr *&slot = impl.boolAttrs[value];
+  if (!slot) {
+    slot = impl.allocator.Allocate<BoolAttr>();
+    new (slot) BoolAttr(value);
+  }
+  return slot;
+}
+
+/// Return the uniqued IntegerAttr for `value`, creating it in the context's
+/// bump allocator on the first request.
+IntegerAttr *IntegerAttr::get(int64_t value, MLIRContext *context) {
+  auto &impl = context->getImpl();
+  // NOTE(review): integerAttrs is a DenseMap keyed on int64_t, and
+  // DenseMapInfo reserves two int64_t values (INT64_MAX is one of them) as
+  // its empty/tombstone markers, so those values cannot be uniqued here —
+  // confirm, and consider a map type without reserved keys.
+  IntegerAttr *&slot = impl.integerAttrs[value];
+  if (!slot) {
+    slot = impl.allocator.Allocate<IntegerAttr>();
+    new (slot) IntegerAttr(value);
+  }
+  return slot;
+}
+
+/// Return the uniqued FloatAttr for `value`, creating it in the context's
+/// bump allocator on the first request.
+///
+/// Attributes are keyed on the exact bit pattern of `value` rather than on
+/// floating-point equality, so -0.0 and 0.0 (which compare equal as doubles)
+/// get distinct attributes.
+FloatAttr *FloatAttr::get(double value, MLIRContext *context) {
+  // Extract the bits with DoubleToBits (a memcpy-based conversion). Writing
+  // one union member and reading another, as done previously, is undefined
+  // behavior in C++. The resulting key bits are unchanged.
+  auto bits = static_cast<int64_t>(llvm::DoubleToBits(value));
+
+  // NOTE(review): floatAttrs is a DenseMap keyed on int64_t, whose
+  // DenseMapInfo reserves two key values as empty/tombstone markers. Bit
+  // patterns equal to those cannot be stored; depending on how int64_t is
+  // typedef'd on the platform, one of them may be the pattern of -0.0 —
+  // confirm, and consider a container without reserved keys.
+  auto &impl = context->getImpl();
+  auto *&result = impl.floatAttrs[bits];
+  if (result)
+    return result;
+
+  result = impl.allocator.Allocate<FloatAttr>();
+  new (result) FloatAttr(value);
+  return result;
+}
+
+/// Return the uniqued StringAttr for `bytes`, creating it on the first
+/// request.
+///
+/// The attribute is built from the key stored in the stringAttrs map entry
+/// rather than from `bytes`, so it does not reference the caller's buffer.
+StringAttr *StringAttr::get(StringRef bytes, MLIRContext *context) {
+  auto &impl = context->getImpl();
+  auto &entry = *impl.stringAttrs.insert({bytes, nullptr}).first;
+  if (!entry.second) {
+    auto *attr = impl.allocator.Allocate<StringAttr>();
+    new (attr) StringAttr(entry.first());
+    entry.second = attr;
+  }
+  return entry.second;
+}
+
+/// Return the uniqued ArrayAttr whose elements are `value`, creating it on
+/// the first request.
+///
+/// On creation the element list is copied into the context's bump allocator
+/// (impl.copyInto), so the attribute does not reference the caller's array.
+ArrayAttr *ArrayAttr::get(ArrayRef<Attribute*> value, MLIRContext *context) {
+  auto &impl = context->getImpl();
+
+  // Probe the set by element list. On a miss, insert_as stores a nullptr
+  // placeholder in the new slot, which is overwritten below.
+  auto insertion = impl.arrayAttrs.insert_as(nullptr, value);
+  auto slot = insertion.first;
+  bool isNew = insertion.second;
+  if (!isNew)
+    return *slot;
+
+  // Reserve the attribute object, give it its own copy of the elements, and
+  // construct it in place.
+  auto *result = impl.allocator.Allocate<ArrayAttr>();
+  ArrayRef<Attribute*> ownedElements = impl.copyInto(value);
+  new (result) ArrayAttr(ownedElements);
// Cache and return it.
- return existing.first->second = result;
+  return *slot = result;
}
+//===----------------------------------------------------------------------===//
+// AffineMap and AffineExpr uniquing
+//===----------------------------------------------------------------------===//
+
AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> results,
MLIRContext *context) {
@@ -348,8 +450,7 @@
auto &impl = context->getImpl();
// Check if we already have this affine map.
- AffineMapKeyInfo::KeyTy key(
- std::pair<unsigned, unsigned>(dimCount, symbolCount), results);
+ auto key = std::make_tuple(dimCount, symbolCount, results);
auto existing = impl.affineMaps.insert_as(nullptr, key);
// If we already have it, return that value.
@@ -369,7 +470,6 @@
return *existing.first = res;
}
-// TODO(bondhugula): complete uniqu'ing of remaining AffinExpr sub-classes
AffineBinaryOpExpr *AffineBinaryOpExpr::get(AffineExpr::Kind kind,
AffineExpr *lhsOperand,
AffineExpr *rhsOperand,
@@ -377,10 +477,8 @@
auto &impl = context->getImpl();
// Check if we already have this affine expression.
- auto key = std::pair<unsigned, std::pair<AffineExpr *, AffineExpr *>>(
- (unsigned)kind,
- std::pair<AffineExpr *, AffineExpr *>(lhsOperand, rhsOperand));
- auto *&result = impl.affineExprs[key];
+ auto keyValue = std::make_tuple((unsigned)kind, lhsOperand, rhsOperand);
+ auto *&result = impl.affineExprs[keyValue];
// If we already have it, return that value.
if (!result) {
@@ -393,6 +491,7 @@
return result;
}
+// TODO(bondhugula): complete uniquing of remaining AffineExpr sub-classes.
AffineAddExpr *AffineAddExpr::get(AffineExpr *lhsOperand,
AffineExpr *rhsOperand,
MLIRContext *context) {