Implement IR support for attributes.
PiperOrigin-RevId: 203293376
diff --git a/include/mlir/IR/Attributes.h b/include/mlir/IR/Attributes.h
new file mode 100644
index 0000000..9b0f580
--- /dev/null
+++ b/include/mlir/IR/Attributes.h
@@ -0,0 +1,153 @@
+//===- Attributes.h - MLIR Attribute Classes --------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_IR_ATTRIBUTES_H
+#define MLIR_IR_ATTRIBUTES_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+ class MLIRContext;
+
+/// Instances of the Attribute class are immutable, uniqued, immortal, and owned
+/// by MLIRContext. As such, they are passed around by raw non-const pointer.
+class Attribute {
+public:
+ enum class Kind {
+ Bool,
+ Integer,
+ Float,
+ String,
+ Array,
+ // TODO: Function references.
+ };
+
+ /// Return the classification for this attribute.
+ Kind getKind() const {
+ return kind;
+ }
+
+ /// Print the attribute.
+ void print(raw_ostream &os) const;
+ void dump() const;
+
+protected:
+ explicit Attribute(Kind kind) : kind(kind) {}
+
+private:
+ /// Classification of the subclass, used for type checking.
+ Kind kind : 8;
+
+ Attribute(const Attribute&) = delete;
+ void operator=(const Attribute&) = delete;
+};
+
+inline raw_ostream &operator<<(raw_ostream &os, const Attribute &attr) {
+ attr.print(os);
+ return os;
+}
+
+class BoolAttr : public Attribute {
+public:
+ static BoolAttr *get(bool value, MLIRContext *context);
+
+ bool getValue() const {
+ return value;
+ }
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(const Attribute *attr) {
+ return attr->getKind() == Kind::Bool;
+ }
+private:
+ BoolAttr(bool value) : Attribute(Kind::Bool), value(value) {}
+ bool value;
+};
+
+class IntegerAttr : public Attribute {
+public:
+ static IntegerAttr *get(int64_t value, MLIRContext *context);
+
+ unsigned getValue() const {
+ return value;
+ }
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(const Attribute *attr) {
+ return attr->getKind() == Kind::Integer;
+ }
+private:
+ IntegerAttr(int64_t value) : Attribute(Kind::Integer), value(value) {}
+ int64_t value;
+};
+
+class FloatAttr : public Attribute {
+public:
+ static FloatAttr *get(double value, MLIRContext *context);
+
+ double getValue() const {
+ return value;
+ }
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(const Attribute *attr) {
+ return attr->getKind() == Kind::Float;
+ }
+private:
+ FloatAttr(double value) : Attribute(Kind::Float), value(value) {}
+ double value;
+};
+
+class StringAttr : public Attribute {
+public:
+ static StringAttr *get(StringRef bytes, MLIRContext *context);
+
+ StringRef getValue() const {
+ return value;
+ }
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(const Attribute *attr) {
+ return attr->getKind() == Kind::String;
+ }
+private:
+ StringAttr(StringRef value) : Attribute(Kind::String), value(value) {}
+ StringRef value;
+};
+
+class ArrayAttr : public Attribute {
+public:
+ static ArrayAttr *get(ArrayRef<Attribute*> value, MLIRContext *context);
+
+ ArrayRef<Attribute*> getValue() const {
+ return value;
+ }
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static bool classof(const Attribute *attr) {
+ return attr->getKind() == Kind::Array;
+ }
+private:
+ ArrayAttr(ArrayRef<Attribute*> value) : Attribute(Kind::Array), value(value){}
+ ArrayRef<Attribute*> value;
+};
+
+} // end namespace mlir.
+
+#endif
+
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index 4042b3e..e4c8c8c 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -56,7 +56,6 @@
// TODO: MemRef types.
};
-
/// Return the classification for this type.
Kind getKind() const {
return kind;
@@ -128,7 +127,6 @@
PrimitiveType(Kind kind, MLIRContext *context);
};
-
inline PrimitiveType *Type::getAffineInt(MLIRContext *ctx) {
return PrimitiveType::get(Kind::AffineInt, ctx);
}
diff --git a/include/mlir/Support/STLExtras.h b/include/mlir/Support/STLExtras.h
index da0b62e..905694b 100644
--- a/include/mlir/Support/STLExtras.h
+++ b/include/mlir/Support/STLExtras.h
@@ -23,6 +23,9 @@
#ifndef MLIR_SUPPORT_STLEXTRAS_H
#define MLIR_SUPPORT_STLEXTRAS_H
+#include "mlir/Support/LLVM.h"
+#include <tuple>
+
namespace mlir {
/// An STL-style algorithm similar to std::for_each that applies a second
@@ -56,6 +59,77 @@
interleave(c.begin(), c.end(), each_fn, between_fn);
}
-} // end namespace swift
+} // end namespace mlir
+
+// Allow tuples to be usable as DenseMap keys.
+// TODO: Move this to upstream LLVM.
+
+/// Simplistic combination of 32-bit hash values into 32-bit hash values.
+/// This function is taken from llvm/ADT/DenseMapInfo.h.
+static inline unsigned llvm_combineHashValue(unsigned a, unsigned b) {
+ uint64_t key = (uint64_t)a << 32 | (uint64_t)b;
+ key += ~(key << 32);
+ key ^= (key >> 22);
+ key += ~(key << 13);
+ key ^= (key >> 8);
+ key += (key << 3);
+ key ^= (key >> 15);
+ key += ~(key << 27);
+ key ^= (key >> 31);
+ return (unsigned)key;
+}
+
+namespace llvm {
+template<typename ...Ts>
+struct DenseMapInfo<std::tuple<Ts...> > {
+ typedef std::tuple<Ts...> Tuple;
+
+ static inline Tuple getEmptyKey() {
+ return Tuple(DenseMapInfo<Ts>::getEmptyKey()...);
+ }
+
+ static inline Tuple getTombstoneKey() {
+ return Tuple(DenseMapInfo<Ts>::getTombstoneKey()...);
+ }
+
+ template<unsigned I>
+ static unsigned getHashValueImpl(const Tuple& values, std::false_type) {
+ typedef typename std::tuple_element<I, Tuple>::type EltType;
+ std::integral_constant<bool, I+1 == sizeof...(Ts)> atEnd;
+ return llvm_combineHashValue(
+ DenseMapInfo<EltType>::getHashValue(std::get<I>(values)),
+ getHashValueImpl<I+1>(values, atEnd));
+ }
+
+ template<unsigned I>
+ static unsigned getHashValueImpl(const Tuple& values, std::true_type) {
+ return 0;
+ }
+
+ static unsigned getHashValue(const std::tuple<Ts...>& values) {
+ std::integral_constant<bool, 0 == sizeof...(Ts)> atEnd;
+ return getHashValueImpl<0>(values, atEnd);
+ }
+
+ template<unsigned I>
+ static bool isEqualImpl(const Tuple &lhs, const Tuple &rhs, std::false_type) {
+ typedef typename std::tuple_element<I, Tuple>::type EltType;
+ std::integral_constant<bool, I+1 == sizeof...(Ts)> atEnd;
+ return DenseMapInfo<EltType>::isEqual(std::get<I>(lhs), std::get<I>(rhs))
+ && isEqualImpl<I+1>(lhs, rhs, atEnd);
+ }
+
+ template<unsigned I>
+ static bool isEqualImpl(const Tuple &lhs, const Tuple &rhs, std::true_type) {
+ return true;
+ }
+
+ static bool isEqual(const Tuple &lhs, const Tuple &rhs) {
+ std::integral_constant<bool, 0 == sizeof...(Ts)> atEnd;
+ return isEqualImpl<0>(lhs, rhs, atEnd);
+ }
+};
+
+}
#endif // MLIR_SUPPORT_STLEXTRAS_H
diff --git a/lib/IR/Attributes.cpp b/lib/IR/Attributes.cpp
new file mode 100644
index 0000000..0a773f1
--- /dev/null
+++ b/lib/IR/Attributes.cpp
@@ -0,0 +1,57 @@
+//===- Attributes.cpp - MLIR Attribute Implementation ---------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Attributes.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Support/STLExtras.h"
+using namespace mlir;
+
+void Attribute::print(raw_ostream &os) const {
+ switch (getKind()) {
+ case Kind::Bool:
+ os << (cast<BoolAttr>(this)->getValue() ? "true" : "false");
+ break;
+ case Kind::Integer:
+ os << cast<IntegerAttr>(this)->getValue();
+ break;
+ case Kind::Float:
+ // FIXME: this isn't precise, we should print with a hex format.
+ os << cast<FloatAttr>(this)->getValue();
+ break;
+ case Kind::String:
+ // FIXME: should escape the string.
+ os << '"' << cast<StringAttr>(this)->getValue() << '"';
+ break;
+ case Kind::Array: {
+ auto elts = cast<ArrayAttr>(this)->getValue();
+ if (elts.empty())
+ os << "[]";
+ else {
+ os << "[ ";
+ interleave(elts,
+ [&](Attribute *attr) { attr->print(os); },
+ [&]() { os << ", "; });
+ os << " ]";
+ }
+ break;
+ }
+ }
+}
+
+void Attribute::dump() const {
+ print(llvm::errs());
+}
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index d148523..3524909 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -16,11 +16,12 @@
// =============================================================================
#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Identifier.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Identifier.h"
#include "mlir/IR/Types.h"
-#include "mlir/Support/LLVM.h"
+#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Allocator.h"
@@ -50,23 +51,21 @@
struct AffineMapKeyInfo : DenseMapInfo<AffineMap *> {
// Affine maps are uniqued based on their dim/symbol counts and affine
// expressions.
- using KeyTy =
- std::pair<std::pair<unsigned, unsigned>, ArrayRef<AffineExpr *>>;
+ using KeyTy = std::tuple<unsigned, unsigned, ArrayRef<AffineExpr *>>;
using DenseMapInfo<AffineMap *>::getHashValue;
using DenseMapInfo<AffineMap *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
- key.first.first, key.first.second,
- hash_combine_range(key.second.begin(), key.second.end()));
+ std::get<0>(key), std::get<1>(key),
+ hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()));
}
static bool isEqual(const KeyTy &lhs, const AffineMap *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
- return lhs == KeyTy(std::pair<unsigned, unsigned>(rhs->getNumDims(),
- rhs->getNumSymbols()),
- rhs->getResults());
+ return lhs == std::make_tuple(rhs->getNumDims(), rhs->getNumSymbols(),
+ rhs->getResults());
}
};
@@ -88,6 +87,7 @@
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
}
};
+
struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType*> {
// Ranked tensors are uniqued based on their element type and shape.
using KeyTy = std::pair<Type*, ArrayRef<int>>;
@@ -106,6 +106,23 @@
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
}
};
+
+struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttr*> {
+ // Array attributes are uniqued based on their elements.
+ using KeyTy = ArrayRef<Attribute*>;
+ using DenseMapInfo<ArrayAttr*>::getHashValue;
+ using DenseMapInfo<ArrayAttr*>::isEqual;
+
+ static unsigned getHashValue(KeyTy key) {
+ return hash_combine_range(key.begin(), key.end());
+ }
+
+ static bool isEqual(const KeyTy &lhs, const ArrayAttr *rhs) {
+ if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+ return false;
+ return lhs == rhs->getValue();
+ }
+};
} // end anonymous namespace.
@@ -129,8 +146,7 @@
// Affine binary op expression uniquing. Figure out uniquing of dimensional
// or symbolic identifiers.
- // std::tuple doesn't work with DenseMap!, using nested pair.
- DenseMap<std::pair<unsigned, std::pair<AffineExpr *, AffineExpr *>>,
+ DenseMap<std::tuple<unsigned, AffineExpr *, AffineExpr *>,
AffineBinaryOpExpr *>
affineExprs;
@@ -153,6 +169,13 @@
/// Unranked tensor type uniquing.
DenseMap<Type*, UnrankedTensorType*> unrankedTensors;
+ // Attribute uniquing.
+ BoolAttr *boolAttrs[2] = { nullptr };
+ DenseMap<int64_t, IntegerAttr*> integerAttrs;
+ DenseMap<int64_t, FloatAttr*> floatAttrs;
+ StringMap<StringAttr*> stringAttrs;
+ using ArrayAttrSet = DenseSet<ArrayAttr*, ArrayAttrKeyInfo>;
+ ArrayAttrSet arrayAttrs;
public:
MLIRContextImpl() : identifiers(allocator) {}
@@ -176,7 +199,7 @@
//===----------------------------------------------------------------------===//
-// Identifier
+// Identifier uniquing
//===----------------------------------------------------------------------===//
/// Return an identifier for the specified string.
@@ -191,7 +214,7 @@
}
//===----------------------------------------------------------------------===//
-// Types
+// Type uniquing
//===----------------------------------------------------------------------===//
PrimitiveType *PrimitiveType::get(Kind kind, MLIRContext *context) {
@@ -323,22 +346,101 @@
auto &impl = context->getImpl();
// Look to see if we already have this unranked tensor type.
- auto existing = impl.unrankedTensors.insert({elementType, nullptr});
+ auto *&result = impl.unrankedTensors[elementType];
// If we already have it, return that value.
- if (!existing.second)
- return existing.first->second;
+ if (result)
+ return result;
// On the first use, we allocate them into the bump pointer.
- auto *result = impl.allocator.Allocate<UnrankedTensorType>();
+ result = impl.allocator.Allocate<UnrankedTensorType>();
// Initialize the memory using placement new.
new (result) UnrankedTensorType(elementType, context);
+ return result;
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute uniquing
+//===----------------------------------------------------------------------===//
+
+BoolAttr *BoolAttr::get(bool value, MLIRContext *context) {
+ auto *&result = context->getImpl().boolAttrs[value];
+ if (result)
+ return result;
+
+ result = context->getImpl().allocator.Allocate<BoolAttr>();
+ new (result) BoolAttr(value);
+ return result;
+}
+
+IntegerAttr *IntegerAttr::get(int64_t value, MLIRContext *context) {
+ auto *&result = context->getImpl().integerAttrs[value];
+ if (result)
+ return result;
+
+ result = context->getImpl().allocator.Allocate<IntegerAttr>();
+ new (result) IntegerAttr(value);
+ return result;
+}
+
+FloatAttr *FloatAttr::get(double value, MLIRContext *context) {
+ // We hash based on the bit representation of the double to ensure we don't
+ // merge things like -0.0 and 0.0 in the hash comparison.
+ union {
+ double floatValue;
+ int64_t intValue;
+ };
+ floatValue = value;
+
+ auto *&result = context->getImpl().floatAttrs[intValue];
+ if (result)
+ return result;
+
+ result = context->getImpl().allocator.Allocate<FloatAttr>();
+ new (result) FloatAttr(value);
+ return result;
+}
+
+StringAttr *StringAttr::get(StringRef bytes, MLIRContext *context) {
+ auto it = context->getImpl().stringAttrs.insert({bytes, nullptr}).first;
+
+ if (it->second)
+ return it->second;
+
+ auto result = context->getImpl().allocator.Allocate<StringAttr>();
+ new (result) StringAttr(it->first());
+ it->second = result;
+ return result;
+}
+
+ArrayAttr *ArrayAttr::get(ArrayRef<Attribute*> value, MLIRContext *context) {
+ auto &impl = context->getImpl();
+
+ // Look to see if we already have this.
+ auto existing = impl.arrayAttrs.insert_as(nullptr, value);
+
+ // If we already have it, return that value.
+ if (!existing.second)
+ return *existing.first;
+
+ // On the first use, we allocate them into the bump pointer.
+ auto *result = impl.allocator.Allocate<ArrayAttr>();
+
+ // Copy the elements into the bump pointer.
+ value = impl.copyInto(value);
+
+ // Initialize the memory using placement new.
+ new (result) ArrayAttr(value);
// Cache and return it.
- return existing.first->second = result;
+ return *existing.first = result;
}
+//===----------------------------------------------------------------------===//
+// AffineMap and AffineExpr uniquing
+//===----------------------------------------------------------------------===//
+
AffineMap *AffineMap::get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr *> results,
MLIRContext *context) {
@@ -348,8 +450,7 @@
auto &impl = context->getImpl();
// Check if we already have this affine map.
- AffineMapKeyInfo::KeyTy key(
- std::pair<unsigned, unsigned>(dimCount, symbolCount), results);
+ auto key = std::make_tuple(dimCount, symbolCount, results);
auto existing = impl.affineMaps.insert_as(nullptr, key);
// If we already have it, return that value.
@@ -369,7 +470,6 @@
return *existing.first = res;
}
-// TODO(bondhugula): complete uniqu'ing of remaining AffinExpr sub-classes
AffineBinaryOpExpr *AffineBinaryOpExpr::get(AffineExpr::Kind kind,
AffineExpr *lhsOperand,
AffineExpr *rhsOperand,
@@ -377,10 +477,8 @@
auto &impl = context->getImpl();
// Check if we already have this affine expression.
- auto key = std::pair<unsigned, std::pair<AffineExpr *, AffineExpr *>>(
- (unsigned)kind,
- std::pair<AffineExpr *, AffineExpr *>(lhsOperand, rhsOperand));
- auto *&result = impl.affineExprs[key];
+ auto keyValue = std::make_tuple((unsigned)kind, lhsOperand, rhsOperand);
+ auto *&result = impl.affineExprs[keyValue];
// If we already have it, return that value.
if (!result) {
@@ -393,6 +491,7 @@
return result;
}
+// TODO(bondhugula): complete uniquing of remaining AffineExpr sub-classes.
AffineAddExpr *AffineAddExpr::get(AffineExpr *lhsOperand,
AffineExpr *rhsOperand,
MLIRContext *context) {