Implement value type abstraction for types.
This is done by changing Type into a POD value class that wraps a pointer to the underlying type storage, and by adding in-class support for isa/dyn_cast/cast.
PiperOrigin-RevId: 219372163
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 1b3c24f..1904a63 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -118,15 +118,15 @@
return tripCountExpr.getLargestKnownDivisor();
}
-bool mlir::isAccessInvariant(const MLValue &input, MemRefType *memRefType,
+bool mlir::isAccessInvariant(const MLValue &input, MemRefType memRefType,
ArrayRef<MLValue *> indices, unsigned dim) {
- assert(indices.size() == memRefType->getRank());
+ assert(indices.size() == memRefType.getRank());
assert(dim < indices.size());
- auto layoutMap = memRefType->getAffineMaps();
- assert(memRefType->getAffineMaps().size() <= 1);
+ auto layoutMap = memRefType.getAffineMaps();
+ assert(memRefType.getAffineMaps().size() <= 1);
// TODO(ntv): remove dependency on Builder once we support non-identity
// layout map.
- Builder b(memRefType->getContext());
+ Builder b(memRefType.getContext());
assert(layoutMap.empty() ||
layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
(void)layoutMap;
@@ -170,7 +170,7 @@
using namespace functional;
auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); },
memoryOp->getIndices());
- auto *memRefType = memoryOp->getMemRefType();
+ auto memRefType = memoryOp->getMemRefType();
for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) {
if (fastestVaryingDim == (numIndices - 1) - d) {
continue;
@@ -184,8 +184,8 @@
template <typename LoadOrStoreOpPointer>
static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
- auto *memRefType = memoryOp->getMemRefType();
- return isa<VectorType>(memRefType->getElementType());
+ auto memRefType = memoryOp->getMemRefType();
+ return memRefType.getElementType().template isa<VectorType>();
}
bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) {
diff --git a/lib/Analysis/Verifier.cpp b/lib/Analysis/Verifier.cpp
index bfbcb16..0dd030d 100644
--- a/lib/Analysis/Verifier.cpp
+++ b/lib/Analysis/Verifier.cpp
@@ -195,7 +195,7 @@
// Verify that the argument list of the function and the arg list of the first
// block line up.
- auto fnInputTypes = fn.getType()->getInputs();
+ auto fnInputTypes = fn.getType().getInputs();
if (fnInputTypes.size() != firstBB->getNumArguments())
return failure("first block of cfgfunc must have " +
Twine(fnInputTypes.size()) +
@@ -306,7 +306,7 @@
bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) {
// Verify that the return operands match the results of the function.
- auto results = fn.getType()->getResults();
+ auto results = fn.getType().getResults();
if (inst.getNumOperands() != results.size())
return failure("return has " + Twine(inst.getNumOperands()) +
" operands, but enclosing function returns " +
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 454a28a..cb5e96f 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -122,7 +122,7 @@
void visitForStmt(const ForStmt *forStmt);
void visitIfStmt(const IfStmt *ifStmt);
void visitOperationStmt(const OperationStmt *opStmt);
- void visitType(const Type *type);
+ void visitType(Type type);
void visitAttribute(Attribute attr);
void visitOperation(const Operation *op);
@@ -135,16 +135,16 @@
} // end anonymous namespace
// TODO Support visiting other types/instructions when implemented.
-void ModuleState::visitType(const Type *type) {
- if (auto *funcType = dyn_cast<FunctionType>(type)) {
+void ModuleState::visitType(Type type) {
+ if (auto funcType = type.dyn_cast<FunctionType>()) {
// Visit input and result types for functions.
- for (auto *input : funcType->getInputs())
+ for (auto input : funcType.getInputs())
visitType(input);
- for (auto *result : funcType->getResults())
+ for (auto result : funcType.getResults())
visitType(result);
- } else if (auto *memref = dyn_cast<MemRefType>(type)) {
+ } else if (auto memref = type.dyn_cast<MemRefType>()) {
// Visit affine maps in memref type.
- for (auto map : memref->getAffineMaps()) {
+ for (auto map : memref.getAffineMaps()) {
recordAffineMapReference(map);
}
}
@@ -271,7 +271,7 @@
void print(const Module *module);
void printFunctionReference(const Function *func);
void printAttribute(Attribute attr);
- void printType(const Type *type);
+ void printType(Type type);
void print(const Function *fn);
void print(const ExtFunction *fn);
void print(const CFGFunction *fn);
@@ -290,7 +290,7 @@
void printFunctionAttributes(const Function *fn);
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<const char *> elidedAttrs = {});
- void printFunctionResultType(const FunctionType *type);
+ void printFunctionResultType(FunctionType type);
void printAffineMapId(int affineMapId) const;
void printAffineMapReference(AffineMap affineMap);
void printIntegerSetId(int integerSetId) const;
@@ -489,9 +489,9 @@
}
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
- auto *type = attr.getType();
- auto shape = type->getShape();
- auto rank = type->getRank();
+ auto type = attr.getType();
+ auto shape = type.getShape();
+ auto rank = type.getRank();
SmallVector<Attribute, 16> elements;
attr.getValues(elements);
@@ -541,8 +541,8 @@
os << ']';
}
-void ModulePrinter::printType(const Type *type) {
- switch (type->getKind()) {
+void ModulePrinter::printType(Type type) {
+ switch (type.getKind()) {
case Type::Kind::Index:
os << "index";
return;
@@ -581,71 +581,71 @@
return;
case Type::Kind::Integer: {
- auto *integer = cast<IntegerType>(type);
- os << 'i' << integer->getWidth();
+ auto integer = type.cast<IntegerType>();
+ os << 'i' << integer.getWidth();
return;
}
case Type::Kind::Function: {
- auto *func = cast<FunctionType>(type);
+ auto func = type.cast<FunctionType>();
os << '(';
- interleaveComma(func->getInputs(), [&](Type *type) { printType(type); });
+ interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
os << ") -> ";
- auto results = func->getResults();
+ auto results = func.getResults();
if (results.size() == 1)
- os << *results[0];
+ os << results[0];
else {
os << '(';
- interleaveComma(results, [&](Type *type) { printType(type); });
+ interleaveComma(results, [&](Type type) { printType(type); });
os << ')';
}
return;
}
case Type::Kind::Vector: {
- auto *v = cast<VectorType>(type);
+ auto v = type.cast<VectorType>();
os << "vector<";
- for (auto dim : v->getShape())
+ for (auto dim : v.getShape())
os << dim << 'x';
- os << *v->getElementType() << '>';
+ os << v.getElementType() << '>';
return;
}
case Type::Kind::RankedTensor: {
- auto *v = cast<RankedTensorType>(type);
+ auto v = type.cast<RankedTensorType>();
os << "tensor<";
- for (auto dim : v->getShape()) {
+ for (auto dim : v.getShape()) {
if (dim < 0)
os << '?';
else
os << dim;
os << 'x';
}
- os << *v->getElementType() << '>';
+ os << v.getElementType() << '>';
return;
}
case Type::Kind::UnrankedTensor: {
- auto *v = cast<UnrankedTensorType>(type);
+ auto v = type.cast<UnrankedTensorType>();
os << "tensor<*x";
- printType(v->getElementType());
+ printType(v.getElementType());
os << '>';
return;
}
case Type::Kind::MemRef: {
- auto *v = cast<MemRefType>(type);
+ auto v = type.cast<MemRefType>();
os << "memref<";
- for (auto dim : v->getShape()) {
+ for (auto dim : v.getShape()) {
if (dim < 0)
os << '?';
else
os << dim;
os << 'x';
}
- printType(v->getElementType());
- for (auto map : v->getAffineMaps()) {
+ printType(v.getElementType());
+ for (auto map : v.getAffineMaps()) {
os << ", ";
printAffineMapReference(map);
}
// Only print the memory space if it is the non-default one.
- if (v->getMemorySpace())
- os << ", " << v->getMemorySpace();
+ if (v.getMemorySpace())
+ os << ", " << v.getMemorySpace();
os << '>';
return;
}
@@ -842,18 +842,18 @@
// Function printing
//===----------------------------------------------------------------------===//
-void ModulePrinter::printFunctionResultType(const FunctionType *type) {
- switch (type->getResults().size()) {
+void ModulePrinter::printFunctionResultType(FunctionType type) {
+ switch (type.getResults().size()) {
case 0:
break;
case 1:
os << " -> ";
- printType(type->getResults()[0]);
+ printType(type.getResults()[0]);
break;
default:
os << " -> (";
- interleaveComma(type->getResults(),
- [&](Type *eltType) { printType(eltType); });
+ interleaveComma(type.getResults(),
+ [&](Type eltType) { printType(eltType); });
os << ')';
break;
}
@@ -871,8 +871,7 @@
auto type = fn->getType();
os << "@" << fn->getName() << '(';
- interleaveComma(type->getInputs(),
- [&](Type *eltType) { printType(eltType); });
+ interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); });
os << ')';
printFunctionResultType(type);
@@ -937,7 +936,7 @@
// Implement OpAsmPrinter.
raw_ostream &getStream() const { return os; }
- void printType(const Type *type) { ModulePrinter::printType(type); }
+ void printType(Type type) { ModulePrinter::printType(type); }
void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
void printAffineMap(AffineMap map) {
return ModulePrinter::printAffineMapReference(map);
@@ -974,10 +973,10 @@
if (auto *op = value->getDefiningOperation()) {
if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
// i1 constants get special names.
- if (intOp->getType()->isInteger(1)) {
+ if (intOp->getType().isInteger(1)) {
specialName << (intOp->getValue() ? "true" : "false");
} else {
- specialName << 'c' << intOp->getValue() << '_' << *intOp->getType();
+ specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
}
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
specialName << 'c' << intOp->getValue();
@@ -1579,7 +1578,7 @@
void Type::print(raw_ostream &os) const {
ModuleState state(getContext());
- ModulePrinter(os, state).printType(this);
+ ModulePrinter(os, state).printType(*this);
}
void Type::dump() const { print(llvm::errs()); }
diff --git a/lib/IR/AttributeDetail.h b/lib/IR/AttributeDetail.h
index a0e9afb..63ad544 100644
--- a/lib/IR/AttributeDetail.h
+++ b/lib/IR/AttributeDetail.h
@@ -26,6 +26,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Types.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
@@ -86,7 +87,7 @@
/// An attribute representing a reference to a type.
struct TypeAttributeStorage : public AttributeStorage {
- Type *value;
+ Type value;
};
/// An attribute representing a reference to a function.
@@ -96,7 +97,7 @@
/// A base attribute representing a reference to a vector or tensor constant.
struct ElementsAttributeStorage : public AttributeStorage {
- VectorOrTensorType *type;
+ VectorOrTensorType type;
};
/// An attribute representing a reference to a vector or tensor constant,
diff --git a/lib/IR/Attributes.cpp b/lib/IR/Attributes.cpp
index 34312b8..58b5b90 100644
--- a/lib/IR/Attributes.cpp
+++ b/lib/IR/Attributes.cpp
@@ -75,9 +75,7 @@
TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
-Type *TypeAttr::getValue() const {
- return static_cast<ImplType *>(attr)->value;
-}
+Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
@@ -85,11 +83,11 @@
return static_cast<ImplType *>(attr)->value;
}
-FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
+FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
-VectorOrTensorType *ElementsAttr::getType() const {
+VectorOrTensorType ElementsAttr::getType() const {
return static_cast<ImplType *>(attr)->type;
}
@@ -166,8 +164,8 @@
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
- auto elementNum = getType()->getNumElements();
- auto context = getType()->getContext();
+ auto elementNum = getType().getNumElements();
+ auto context = getType().getContext();
values.reserve(elementNum);
if (bitsWidth == 64) {
ArrayRef<int64_t> vs(
@@ -192,8 +190,8 @@
: DenseElementsAttr(ptr) {}
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
- auto elementNum = getType()->getNumElements();
- auto context = getType()->getContext();
+ auto elementNum = getType().getNumElements();
+ auto context = getType().getContext();
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
getRawData().size() / 8});
values.reserve(elementNum);
diff --git a/lib/IR/BasicBlock.cpp b/lib/IR/BasicBlock.cpp
index bb8ac75..29a5ce1 100644
--- a/lib/IR/BasicBlock.cpp
+++ b/lib/IR/BasicBlock.cpp
@@ -33,18 +33,18 @@
// Argument list management.
//===----------------------------------------------------------------------===//
-BBArgument *BasicBlock::addArgument(Type *type) {
+BBArgument *BasicBlock::addArgument(Type type) {
auto *arg = new BBArgument(type, this);
arguments.push_back(arg);
return arg;
}
/// Add one argument to the argument list for each type specified in the list.
-auto BasicBlock::addArguments(ArrayRef<Type *> types)
+auto BasicBlock::addArguments(ArrayRef<Type> types)
-> llvm::iterator_range<args_iterator> {
arguments.reserve(arguments.size() + types.size());
auto initialSize = arguments.size();
- for (auto *type : types) {
+ for (auto type : types) {
addArgument(type);
}
return {arguments.data() + initialSize, arguments.data() + arguments.size()};
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 22d749a..906b580 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -52,59 +52,58 @@
// Types.
//===----------------------------------------------------------------------===//
-FloatType *Builder::getBF16Type() { return Type::getBF16(context); }
+FloatType Builder::getBF16Type() { return Type::getBF16(context); }
-FloatType *Builder::getF16Type() { return Type::getF16(context); }
+FloatType Builder::getF16Type() { return Type::getF16(context); }
-FloatType *Builder::getF32Type() { return Type::getF32(context); }
+FloatType Builder::getF32Type() { return Type::getF32(context); }
-FloatType *Builder::getF64Type() { return Type::getF64(context); }
+FloatType Builder::getF64Type() { return Type::getF64(context); }
-OtherType *Builder::getIndexType() { return Type::getIndex(context); }
+OtherType Builder::getIndexType() { return Type::getIndex(context); }
-OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
+OtherType Builder::getTFControlType() { return Type::getTFControl(context); }
-OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
+OtherType Builder::getTFResourceType() { return Type::getTFResource(context); }
-OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); }
+OtherType Builder::getTFVariantType() { return Type::getTFVariant(context); }
-OtherType *Builder::getTFComplex64Type() {
+OtherType Builder::getTFComplex64Type() {
return Type::getTFComplex64(context);
}
-OtherType *Builder::getTFComplex128Type() {
+OtherType Builder::getTFComplex128Type() {
return Type::getTFComplex128(context);
}
-OtherType *Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
+OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
-OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
+OtherType Builder::getTFStringType() { return Type::getTFString(context); }
-IntegerType *Builder::getIntegerType(unsigned width) {
+IntegerType Builder::getIntegerType(unsigned width) {
return Type::getInteger(width, context);
}
-FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs,
- ArrayRef<Type *> results) {
+FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
+ ArrayRef<Type> results) {
return FunctionType::get(inputs, results, context);
}
-MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType,
- ArrayRef<AffineMap> affineMapComposition,
- unsigned memorySpace) {
+MemRefType Builder::getMemRefType(ArrayRef<int> shape, Type elementType,
+ ArrayRef<AffineMap> affineMapComposition,
+ unsigned memorySpace) {
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
}
-VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) {
+VectorType Builder::getVectorType(ArrayRef<int> shape, Type elementType) {
return VectorType::get(shape, elementType);
}
-RankedTensorType *Builder::getTensorType(ArrayRef<int> shape,
- Type *elementType) {
+RankedTensorType Builder::getTensorType(ArrayRef<int> shape, Type elementType) {
return RankedTensorType::get(shape, elementType);
}
-UnrankedTensorType *Builder::getTensorType(Type *elementType) {
+UnrankedTensorType Builder::getTensorType(Type elementType) {
return UnrankedTensorType::get(elementType);
}
@@ -144,7 +143,7 @@
return IntegerSetAttr::get(set);
}
-TypeAttr Builder::getTypeAttr(Type *type) {
+TypeAttr Builder::getTypeAttr(Type type) {
return TypeAttr::get(type, context);
}
@@ -152,23 +151,23 @@
return FunctionAttr::get(value, context);
}
-ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type,
Attribute elt) {
return SplatElementsAttr::get(type, elt);
}
-ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
ArrayRef<char> data) {
return DenseElementsAttr::get(type, data);
}
-ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values) {
return SparseElementsAttr::get(type, indices, values);
}
-ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type,
StringRef bytes) {
return OpaqueElementsAttr::get(type, bytes);
}
@@ -296,7 +295,7 @@
OperationStmt *MLFuncBuilder::createOperation(Location *location,
OperationName name,
ArrayRef<MLValue *> operands,
- ArrayRef<Type *> types,
+ ArrayRef<Type> types,
ArrayRef<NamedAttribute> attrs) {
auto *op = OperationStmt::create(location, name, operands, types, attrs,
getContext());
diff --git a/lib/IR/BuiltinOps.cpp b/lib/IR/BuiltinOps.cpp
index 542e67e..e4bca03 100644
--- a/lib/IR/BuiltinOps.cpp
+++ b/lib/IR/BuiltinOps.cpp
@@ -63,7 +63,7 @@
numDims = opInfos.size();
// Parse the optional symbol operands.
- auto *affineIntTy = parser->getBuilder().getIndexType();
+ auto affineIntTy = parser->getBuilder().getIndexType();
if (parser->parseOperandList(opInfos, -1,
OpAsmParser::Delimiter::OptionalSquare) ||
parser->resolveOperands(opInfos, affineIntTy, operands))
@@ -84,7 +84,7 @@
bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder();
- auto *affineIntTy = builder.getIndexType();
+ auto affineIntTy = builder.getIndexType();
AffineMapAttr mapAttr;
unsigned numDims;
@@ -171,7 +171,7 @@
/// Builds a constant op with the specified attribute value and result type.
void ConstantOp::build(Builder *builder, OperationState *result,
- Attribute value, Type *type) {
+ Attribute value, Type type) {
result->addAttribute("value", value);
result->types.push_back(type);
}
@@ -181,12 +181,12 @@
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
if (!getValue().isa<FunctionAttr>())
- *p << " : " << *getType();
+ *p << " : " << getType();
}
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Attribute valueAttr;
- Type *type;
+ Type type;
if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
parser->parseOptionalAttributeDict(result->attributes))
@@ -208,33 +208,33 @@
if (!value)
return emitOpError("requires a 'value' attribute");
- auto *type = this->getType();
- if (isa<IntegerType>(type) || type->isIndex()) {
+ auto type = this->getType();
+ if (type.isa<IntegerType>() || type.isIndex()) {
if (!value.isa<IntegerAttr>())
return emitOpError(
"requires 'value' to be an integer for an integer result type");
return false;
}
- if (isa<FloatType>(type)) {
+ if (type.isa<FloatType>()) {
if (!value.isa<FloatAttr>())
return emitOpError("requires 'value' to be a floating point constant");
return false;
}
- if (isa<VectorOrTensorType>(type)) {
+ if (type.isa<VectorOrTensorType>()) {
if (!value.isa<ElementsAttr>())
return emitOpError("requires 'value' to be a vector/tensor constant");
return false;
}
- if (type->isTFString()) {
+ if (type.isTFString()) {
if (!value.isa<StringAttr>())
return emitOpError("requires 'value' to be a string constant");
return false;
}
- if (isa<FunctionType>(type)) {
+ if (type.isa<FunctionType>()) {
if (!value.isa<FunctionAttr>())
return emitOpError("requires 'value' to be a function reference");
return false;
@@ -251,19 +251,19 @@
}
void ConstantFloatOp::build(Builder *builder, OperationState *result,
- const APFloat &value, FloatType *type) {
+ const APFloat &value, FloatType type) {
ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
}
bool ConstantFloatOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) &&
- isa<FloatType>(op->getResult(0)->getType());
+ op->getResult(0)->getType().isa<FloatType>();
}
/// ConstantIntOp only matches values whose result type is an IntegerType.
bool ConstantIntOp::isClassFor(const Operation *op) {
return ConstantOp::isClassFor(op) &&
- isa<IntegerType>(op->getResult(0)->getType());
+ op->getResult(0)->getType().isa<IntegerType>();
}
void ConstantIntOp::build(Builder *builder, OperationState *result,
@@ -275,14 +275,14 @@
/// Build a constant int op producing an integer with the specified type,
/// which must be an integer type.
void ConstantIntOp::build(Builder *builder, OperationState *result,
- int64_t value, Type *type) {
- assert(isa<IntegerType>(type) && "ConstantIntOp can only have integer type");
+ int64_t value, Type type) {
+ assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
ConstantOp::build(builder, result, builder->getIntegerAttr(value), type);
}
/// ConstantIndexOp only matches values whose result type is Index.
bool ConstantIndexOp::isClassFor(const Operation *op) {
- return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex();
+ return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex();
}
void ConstantIndexOp::build(Builder *builder, OperationState *result,
@@ -302,7 +302,7 @@
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo;
- SmallVector<Type *, 2> types;
+ SmallVector<Type, 2> types;
llvm::SMLoc loc;
return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
@@ -330,7 +330,7 @@
// The operand number and types must match the function signature.
MLFunction *function = cast<MLFunction>(block);
- const auto &results = function->getType()->getResults();
+ const auto &results = function->getType().getResults();
if (stmt->getNumOperands() != results.size())
return emitOpError("has " + Twine(stmt->getNumOperands()) +
" operands, but enclosing function returns " +
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index efeb16b..70c0e12 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -28,8 +28,8 @@
using namespace mlir;
Function::Function(Kind kind, Location *location, StringRef name,
- FunctionType *type, ArrayRef<NamedAttribute> attrs)
- : nameAndKind(Identifier::get(name, type->getContext()), kind),
+ FunctionType type, ArrayRef<NamedAttribute> attrs)
+ : nameAndKind(Identifier::get(name, type.getContext()), kind),
location(location), type(type) {
this->attrs = AttributeListStorage::get(attrs, getContext());
}
@@ -46,7 +46,7 @@
return {};
}
-MLIRContext *Function::getContext() const { return getType()->getContext(); }
+MLIRContext *Function::getContext() const { return getType().getContext(); }
/// Delete this object.
void Function::destroy() {
@@ -159,7 +159,7 @@
// ExtFunction implementation.
//===----------------------------------------------------------------------===//
-ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type,
+ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::ExtFunc, location, name, type, attrs) {}
@@ -167,7 +167,7 @@
// CFGFunction implementation.
//===----------------------------------------------------------------------===//
-CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type,
+CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::CFGFunc, location, name, type, attrs) {}
@@ -188,9 +188,9 @@
/// Create a new MLFunction with the specific fields.
MLFunction *MLFunction::create(Location *location, StringRef name,
- FunctionType *type,
+ FunctionType type,
ArrayRef<NamedAttribute> attrs) {
- const auto &argTypes = type->getInputs();
+ const auto &argTypes = type.getInputs();
auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
void *rawMem = malloc(byteSize);
@@ -204,7 +204,7 @@
return function;
}
-MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type,
+MLFunction::MLFunction(Location *location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::MLFunc, location, name, type, attrs),
StmtBlock(StmtBlockKind::MLFunc) {}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 422636b..d2f49dd 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -143,7 +143,7 @@
/// Create a new OperationInst with the specified fields.
OperationInst *OperationInst::create(Location *location, OperationName name,
ArrayRef<CFGValue *> operands,
- ArrayRef<Type *> resultTypes,
+ ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context) {
auto byteSize = totalSizeToAlloc<InstOperand, InstResult>(operands.size(),
@@ -167,7 +167,7 @@
OperationInst *OperationInst::clone() const {
SmallVector<CFGValue *, 8> operands;
- SmallVector<Type *, 8> resultTypes;
+ SmallVector<Type, 8> resultTypes;
// Put together the operands and results.
for (auto *operand : getOperands())
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 0a2e941..8811f7b 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -21,6 +21,7 @@
#include "AttributeDetail.h"
#include "AttributeListStorage.h"
#include "IntegerSetDetail.h"
+#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
@@ -44,11 +45,11 @@
using namespace llvm;
namespace {
-struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
+struct FunctionTypeKeyInfo : DenseMapInfo<FunctionTypeStorage *> {
// Functions are uniqued based on their inputs and results.
- using KeyTy = std::pair<ArrayRef<Type *>, ArrayRef<Type *>>;
- using DenseMapInfo<FunctionType *>::getHashValue;
- using DenseMapInfo<FunctionType *>::isEqual;
+ using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<Type>>;
+ using DenseMapInfo<FunctionTypeStorage *>::getHashValue;
+ using DenseMapInfo<FunctionTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
@@ -56,7 +57,7 @@
hash_combine_range(key.second.begin(), key.second.end()));
}
- static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
+ static bool isEqual(const KeyTy &lhs, const FunctionTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
@@ -109,65 +110,64 @@
}
};
-struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
+struct VectorTypeKeyInfo : DenseMapInfo<VectorTypeStorage *> {
// Vectors are uniqued based on their element type and shape.
- using KeyTy = std::pair<Type *, ArrayRef<int>>;
- using DenseMapInfo<VectorType *>::getHashValue;
- using DenseMapInfo<VectorType *>::isEqual;
+ using KeyTy = std::pair<Type, ArrayRef<int>>;
+ using DenseMapInfo<VectorTypeStorage *>::getHashValue;
+ using DenseMapInfo<VectorTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
- DenseMapInfo<Type *>::getHashValue(key.first),
+ DenseMapInfo<Type>::getHashValue(key.first),
hash_combine_range(key.second.begin(), key.second.end()));
}
- static bool isEqual(const KeyTy &lhs, const VectorType *rhs) {
+ static bool isEqual(const KeyTy &lhs, const VectorTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
- return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
+ return lhs == KeyTy(rhs->elementType, rhs->getShape());
}
};
-struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType *> {
+struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorTypeStorage *> {
// Ranked tensors are uniqued based on their element type and shape.
- using KeyTy = std::pair<Type *, ArrayRef<int>>;
- using DenseMapInfo<RankedTensorType *>::getHashValue;
- using DenseMapInfo<RankedTensorType *>::isEqual;
+ using KeyTy = std::pair<Type, ArrayRef<int>>;
+ using DenseMapInfo<RankedTensorTypeStorage *>::getHashValue;
+ using DenseMapInfo<RankedTensorTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
- DenseMapInfo<Type *>::getHashValue(key.first),
+ DenseMapInfo<Type>::getHashValue(key.first),
hash_combine_range(key.second.begin(), key.second.end()));
}
- static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
+ static bool isEqual(const KeyTy &lhs, const RankedTensorTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
- return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
+ return lhs == KeyTy(rhs->elementType, rhs->getShape());
}
};
-struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
+struct MemRefTypeKeyInfo : DenseMapInfo<MemRefTypeStorage *> {
// MemRefs are uniqued based on their element type, shape, affine map
// composition, and memory space.
- using KeyTy =
- std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
- using DenseMapInfo<MemRefType *>::getHashValue;
- using DenseMapInfo<MemRefType *>::isEqual;
+ using KeyTy = std::tuple<Type, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
+ using DenseMapInfo<MemRefTypeStorage *>::getHashValue;
+ using DenseMapInfo<MemRefTypeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
- DenseMapInfo<Type *>::getHashValue(std::get<0>(key)),
+ DenseMapInfo<Type>::getHashValue(std::get<0>(key)),
hash_combine_range(std::get<1>(key).begin(), std::get<1>(key).end()),
hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
std::get<3>(key));
}
- static bool isEqual(const KeyTy &lhs, const MemRefType *rhs) {
+ static bool isEqual(const KeyTy &lhs, const MemRefTypeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
- return lhs == std::make_tuple(rhs->getElementType(), rhs->getShape(),
- rhs->getAffineMaps(), rhs->getMemorySpace());
+ return lhs == std::make_tuple(rhs->elementType, rhs->getShape(),
+ rhs->getAffineMaps(), rhs->memorySpace);
}
};
@@ -221,7 +221,7 @@
};
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
- using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>;
+ using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue;
using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
@@ -239,7 +239,7 @@
};
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
- using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
+ using KeyTy = std::pair<VectorOrTensorType, StringRef>;
using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue;
using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
@@ -295,13 +295,14 @@
llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
// Uniquing table for 'other' types.
- OtherType *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
- int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {nullptr};
+ OtherTypeStorage *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
+ int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {
+ nullptr};
// Uniquing table for 'float' types.
- FloatType *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
- int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = {
- nullptr};
+ FloatTypeStorage *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
+ int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] =
+ {nullptr};
// Affine map uniquing.
using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
@@ -324,26 +325,26 @@
DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
/// Integer type uniquing.
- DenseMap<unsigned, IntegerType *> integers;
+ DenseMap<unsigned, IntegerTypeStorage *> integers;
/// Function type uniquing.
- using FunctionTypeSet = DenseSet<FunctionType *, FunctionTypeKeyInfo>;
+ using FunctionTypeSet = DenseSet<FunctionTypeStorage *, FunctionTypeKeyInfo>;
FunctionTypeSet functions;
/// Vector type uniquing.
- using VectorTypeSet = DenseSet<VectorType *, VectorTypeKeyInfo>;
+ using VectorTypeSet = DenseSet<VectorTypeStorage *, VectorTypeKeyInfo>;
VectorTypeSet vectors;
/// Ranked tensor type uniquing.
using RankedTensorTypeSet =
- DenseSet<RankedTensorType *, RankedTensorTypeKeyInfo>;
+ DenseSet<RankedTensorTypeStorage *, RankedTensorTypeKeyInfo>;
RankedTensorTypeSet rankedTensors;
/// Unranked tensor type uniquing.
- DenseMap<Type *, UnrankedTensorType *> unrankedTensors;
+ DenseMap<Type, UnrankedTensorTypeStorage *> unrankedTensors;
/// MemRef type uniquing.
- using MemRefTypeSet = DenseSet<MemRefType *, MemRefTypeKeyInfo>;
+ using MemRefTypeSet = DenseSet<MemRefTypeStorage *, MemRefTypeKeyInfo>;
MemRefTypeSet memrefs;
// Attribute uniquing.
@@ -355,13 +356,12 @@
ArrayAttrSet arrayAttrs;
DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs;
- DenseMap<Type *, TypeAttributeStorage *> typeAttrs;
+ DenseMap<Type, TypeAttributeStorage *> typeAttrs;
using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists;
DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs;
- DenseMap<std::pair<VectorOrTensorType *, Attribute>,
- SplatElementsAttributeStorage *>
+ DenseMap<std::pair<Type, Attribute>, SplatElementsAttributeStorage *>
splatElementsAttrs;
using DenseElementsAttrSet =
DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
@@ -369,7 +369,7 @@
using OpaqueElementsAttrSet =
DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
OpaqueElementsAttrSet opaqueElementsAttrs;
- DenseMap<std::tuple<Type *, Attribute, Attribute>,
+ DenseMap<std::tuple<Type, Attribute, Attribute>,
SparseElementsAttributeStorage *>
sparseElementsAttrs;
@@ -556,19 +556,20 @@
// Type uniquing
//===----------------------------------------------------------------------===//
-IntegerType *IntegerType::get(unsigned width, MLIRContext *context) {
+IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
+ assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
auto &impl = context->getImpl();
auto *&result = impl.integers[width];
if (!result) {
- result = impl.allocator.Allocate<IntegerType>();
- new (result) IntegerType(width, context);
+ result = impl.allocator.Allocate<IntegerTypeStorage>();
+ new (result) IntegerTypeStorage{{Kind::Integer, context}, width};
}
return result;
}
-FloatType *FloatType::get(Kind kind, MLIRContext *context) {
+FloatType FloatType::get(Kind kind, MLIRContext *context) {
assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind");
auto &impl = context->getImpl();
@@ -580,16 +581,16 @@
return entry;
// On the first use, we allocate them into the bump pointer.
- auto *ptr = impl.allocator.Allocate<FloatType>();
+ auto *ptr = impl.allocator.Allocate<FloatTypeStorage>();
// Initialize the memory using placement new.
- new (ptr) FloatType(kind, context);
+ new (ptr) FloatTypeStorage{{kind, context}};
// Cache and return it.
return entry = ptr;
}
-OtherType *OtherType::get(Kind kind, MLIRContext *context) {
+OtherType OtherType::get(Kind kind, MLIRContext *context) {
assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE &&
"Not an 'other' type kind");
auto &impl = context->getImpl();
@@ -600,18 +601,17 @@
return entry;
// On the first use, we allocate them into the bump pointer.
- auto *ptr = impl.allocator.Allocate<OtherType>();
+ auto *ptr = impl.allocator.Allocate<OtherTypeStorage>();
// Initialize the memory using placement new.
- new (ptr) OtherType(kind, context);
+ new (ptr) OtherTypeStorage{{kind, context}};
// Cache and return it.
return entry = ptr;
}
-FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
- ArrayRef<Type *> results,
- MLIRContext *context) {
+FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results,
+ MLIRContext *context) {
auto &impl = context->getImpl();
// Look to see if we already have this function type.
@@ -623,32 +623,34 @@
return *existing.first;
// On the first use, we allocate them into the bump pointer.
- auto *result = impl.allocator.Allocate<FunctionType>();
+ auto *result = impl.allocator.Allocate<FunctionTypeStorage>();
// Copy the inputs and results into the bump pointer.
- SmallVector<Type *, 16> types;
+ SmallVector<Type, 16> types;
types.reserve(inputs.size() + results.size());
types.append(inputs.begin(), inputs.end());
types.append(results.begin(), results.end());
- auto typesList = impl.copyInto(ArrayRef<Type *>(types));
+ auto typesList = impl.copyInto(ArrayRef<Type>(types));
// Initialize the memory using placement new.
- new (result)
- FunctionType(typesList.data(), inputs.size(), results.size(), context);
+ new (result) FunctionTypeStorage{
+ {Kind::Function, context, static_cast<unsigned int>(inputs.size())},
+ static_cast<unsigned int>(results.size()),
+ typesList.data()};
// Cache and return it.
return *existing.first = result;
}
-VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
+VectorType VectorType::get(ArrayRef<int> shape, Type elementType) {
assert(!shape.empty() && "vector types must have at least one dimension");
- assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
+ assert((elementType.isa<FloatType>() || elementType.isa<IntegerType>()) &&
"vectors elements must be primitives");
assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
return i < 0;
}) && "vector types must have static shape");
- auto *context = elementType->getContext();
+ auto *context = elementType.getContext();
auto &impl = context->getImpl();
// Look to see if we already have this vector type.
@@ -660,21 +662,23 @@
return *existing.first;
// On the first use, we allocate them into the bump pointer.
- auto *result = impl.allocator.Allocate<VectorType>();
+ auto *result = impl.allocator.Allocate<VectorTypeStorage>();
// Copy the shape into the bump pointer.
shape = impl.copyInto(shape);
// Initialize the memory using placement new.
- new (result) VectorType(shape, elementType, context);
+ new (result) VectorTypeStorage{
+ {{Kind::Vector, context, static_cast<unsigned int>(shape.size())},
+ elementType},
+ shape.data()};
// Cache and return it.
return *existing.first = result;
}
-RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
- Type *elementType) {
- auto *context = elementType->getContext();
+RankedTensorType RankedTensorType::get(ArrayRef<int> shape, Type elementType) {
+ auto *context = elementType.getContext();
auto &impl = context->getImpl();
// Look to see if we already have this ranked tensor type.
@@ -686,20 +690,23 @@
return *existing.first;
// On the first use, we allocate them into the bump pointer.
- auto *result = impl.allocator.Allocate<RankedTensorType>();
+ auto *result = impl.allocator.Allocate<RankedTensorTypeStorage>();
// Copy the shape into the bump pointer.
shape = impl.copyInto(shape);
// Initialize the memory using placement new.
- new (result) RankedTensorType(shape, elementType, context);
+ new (result) RankedTensorTypeStorage{
+ {{{Kind::RankedTensor, context, static_cast<unsigned int>(shape.size())},
+ elementType}},
+ shape.data()};
// Cache and return it.
return *existing.first = result;
}
-UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
- auto *context = elementType->getContext();
+UnrankedTensorType UnrankedTensorType::get(Type elementType) {
+ auto *context = elementType.getContext();
auto &impl = context->getImpl();
// Look to see if we already have this unranked tensor type.
@@ -710,17 +717,18 @@
return result;
// On the first use, we allocate them into the bump pointer.
- result = impl.allocator.Allocate<UnrankedTensorType>();
+ result = impl.allocator.Allocate<UnrankedTensorTypeStorage>();
// Initialize the memory using placement new.
- new (result) UnrankedTensorType(elementType, context);
+ new (result) UnrankedTensorTypeStorage{
+ {{{Kind::UnrankedTensor, context}, elementType}}};
return result;
}
-MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
- ArrayRef<AffineMap> affineMapComposition,
- unsigned memorySpace) {
- auto *context = elementType->getContext();
+MemRefType MemRefType::get(ArrayRef<int> shape, Type elementType,
+ ArrayRef<AffineMap> affineMapComposition,
+ unsigned memorySpace) {
+ auto *context = elementType.getContext();
auto &impl = context->getImpl();
// Drop the unbounded identity maps from the composition.
@@ -744,7 +752,7 @@
return *existing.first;
// On the first use, we allocate them into the bump pointer.
- auto *result = impl.allocator.Allocate<MemRefType>();
+ auto *result = impl.allocator.Allocate<MemRefTypeStorage>();
// Copy the shape into the bump pointer.
shape = impl.copyInto(shape);
@@ -755,8 +763,13 @@
impl.copyInto(ArrayRef<AffineMap>(affineMapComposition));
// Initialize the memory using placement new.
- new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace,
- context);
+ new (result) MemRefTypeStorage{
+ {Kind::MemRef, context, static_cast<unsigned int>(shape.size())},
+ elementType,
+ shape.data(),
+ static_cast<unsigned int>(affineMapComposition.size()),
+ affineMapComposition.data(),
+ memorySpace};
// Cache and return it.
return *existing.first = result;
}
@@ -895,7 +908,7 @@
return result;
}
-TypeAttr TypeAttr::get(Type *type, MLIRContext *context) {
+TypeAttr TypeAttr::get(Type type, MLIRContext *context) {
auto *&result = context->getImpl().typeAttrs[type];
if (result)
return result;
@@ -1009,9 +1022,9 @@
return *existing.first = result;
}
-SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
+SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
Attribute elt) {
- auto &impl = type->getContext()->getImpl();
+ auto &impl = type.getContext()->getImpl();
// Look to see if we already have this.
auto *&result = impl.splatElementsAttrs[{type, elt}];
@@ -1030,14 +1043,14 @@
return result;
}
-DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
+DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
ArrayRef<char> data) {
- auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
+ auto bitsRequired = (long)type.getBitWidth() * type.getNumElements();
(void)bitsRequired;
assert((bitsRequired <= data.size() * 8L) &&
"Input data bit size should be larger than that type requires");
- auto &impl = type->getContext()->getImpl();
+ auto &impl = type.getContext()->getImpl();
// Look to see if this constant is already defined.
DenseElementsAttrInfo::KeyTy key({type, data});
@@ -1048,8 +1061,8 @@
return *existing.first;
// Otherwise, allocate a new one, unique it and return it.
- auto *eltType = type->getElementType();
- switch (eltType->getKind()) {
+ auto eltType = type.getElementType();
+ switch (eltType.getKind()) {
case Type::Kind::BF16:
case Type::Kind::F16:
case Type::Kind::F32:
@@ -1064,7 +1077,7 @@
return *existing.first = result;
}
case Type::Kind::Integer: {
- auto width = ::cast<IntegerType>(eltType)->getWidth();
+ auto width = eltType.cast<IntegerType>().getWidth();
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy);
@@ -1080,12 +1093,12 @@
}
}
-OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
+OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
StringRef bytes) {
- assert(isValidTensorElementType(type->getElementType()) &&
+ assert(isValidTensorElementType(type.getElementType()) &&
"Input element type should be a valid tensor element type");
- auto &impl = type->getContext()->getImpl();
+ auto &impl = type.getContext()->getImpl();
// Look to see if this constant is already defined.
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
@@ -1104,10 +1117,10 @@
return *existing.first = result;
}
-SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type,
+SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
DenseIntElementsAttr indices,
DenseElementsAttr values) {
- auto &impl = type->getContext()->getImpl();
+ auto &impl = type.getContext()->getImpl();
// Look to see if we already have this.
auto key = std::make_tuple(type, indices, values);
diff --git a/lib/IR/Operation.cpp b/lib/IR/Operation.cpp
index 2ed09b8..0722421 100644
--- a/lib/IR/Operation.cpp
+++ b/lib/IR/Operation.cpp
@@ -377,7 +377,7 @@
}
bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
- auto *type = op->getResult(0)->getType();
+ auto type = op->getResult(0)->getType();
for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
if (op->getResult(i)->getType() != type)
return op->emitOpError(
@@ -393,19 +393,19 @@
/// If this is a vector type, or a tensor type, return the scalar element type
/// that it is built around, otherwise return the type unmodified.
-static Type *getTensorOrVectorElementType(Type *type) {
- if (auto *vec = dyn_cast<VectorType>(type))
- return vec->getElementType();
+static Type getTensorOrVectorElementType(Type type) {
+ if (auto vec = type.dyn_cast<VectorType>())
+ return vec.getElementType();
// Look through tensor<vector<...>> to find the underlying element type.
- if (auto *tensor = dyn_cast<TensorType>(type))
- return getTensorOrVectorElementType(tensor->getElementType());
+ if (auto tensor = type.dyn_cast<TensorType>())
+ return getTensorOrVectorElementType(tensor.getElementType());
return type;
}
bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
for (auto *result : op->getResults()) {
- if (!isa<FloatType>(getTensorOrVectorElementType(result->getType())))
+ if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
return op->emitOpError("requires a floating point type");
}
@@ -414,7 +414,7 @@
bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
for (auto *result : op->getResults()) {
- if (!isa<IntegerType>(getTensorOrVectorElementType(result->getType())))
+ if (!getTensorOrVectorElementType(result->getType()).isa<IntegerType>())
return op->emitOpError("requires an integer type");
}
return false;
@@ -436,7 +436,7 @@
bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
- Type *type;
+ Type type;
return parser->parseOperandList(ops, 2) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
@@ -448,7 +448,7 @@
*p << op->getName() << ' ' << *op->getOperand(0) << ", "
<< *op->getOperand(1);
p->printOptionalAttrDict(op->getAttrs());
- *p << " : " << *op->getResult(0)->getType();
+ *p << " : " << op->getResult(0)->getType();
}
//===----------------------------------------------------------------------===//
@@ -456,14 +456,14 @@
//===----------------------------------------------------------------------===//
void impl::buildCastOp(Builder *builder, OperationState *result,
- SSAValue *source, Type *destType) {
+ SSAValue *source, Type destType) {
result->addOperands(source);
result->addTypes(destType);
}
bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType srcInfo;
- Type *srcType, *dstType;
+ Type srcType, dstType;
return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) ||
parser->resolveOperand(srcInfo, srcType, result->operands) ||
parser->parseKeywordType("to", dstType) ||
@@ -472,5 +472,5 @@
void impl::printCastOp(const Operation *op, OpAsmPrinter *p) {
*p << op->getName() << ' ' << *op->getOperand(0) << " : "
- << *op->getOperand(0)->getType() << " to " << *op->getResult(0)->getType();
+ << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType();
}
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index e9c46d6..698089a1 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -239,7 +239,7 @@
/// Create a new OperationStmt with the specific fields.
OperationStmt *OperationStmt::create(Location *location, OperationName name,
ArrayRef<MLValue *> operands,
- ArrayRef<Type *> resultTypes,
+ ArrayRef<Type> resultTypes,
ArrayRef<NamedAttribute> attributes,
MLIRContext *context) {
auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
@@ -288,9 +288,9 @@
// If we have a result or operand type, that is a constant time way to get
// to the context.
if (getNumResults())
- return getResult(0)->getType()->getContext();
+ return getResult(0)->getType().getContext();
if (getNumOperands())
- return getOperand(0)->getType()->getContext();
+ return getOperand(0)->getType().getContext();
// In the very odd case where we have no operands or results, fall back to
// doing a find.
@@ -474,7 +474,7 @@
if (operands.empty())
return findFunction()->getContext();
- return getOperand(0)->getType()->getContext();
+ return getOperand(0)->getType().getContext();
}
//===----------------------------------------------------------------------===//
@@ -501,7 +501,7 @@
operands.push_back(remapOperand(opValue));
if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
- SmallVector<Type *, 8> resultTypes;
+ SmallVector<Type, 8> resultTypes;
resultTypes.reserve(opStmt->getNumResults());
for (auto *result : opStmt->getResults())
resultTypes.push_back(result->getType());
diff --git a/lib/IR/TypeDetail.h b/lib/IR/TypeDetail.h
new file mode 100644
index 0000000..c22e87a
--- /dev/null
+++ b/lib/IR/TypeDetail.h
@@ -0,0 +1,126 @@
+//===- TypeDetail.h - MLIR Type storage details -----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This holds implementation details of Type.
+//
+//===----------------------------------------------------------------------===//
+#ifndef TYPEDETAIL_H_
+#define TYPEDETAIL_H_
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+
+class AffineMap;
+class MLIRContext;
+
+namespace detail {
+
+/// Base storage class appearing in a Type.
+struct alignas(8) TypeStorage {
+ TypeStorage(Type::Kind kind, MLIRContext *context)
+ : context(context), kind(kind), subclassData(0) {}
+ TypeStorage(Type::Kind kind, MLIRContext *context, unsigned subclassData)
+ : context(context), kind(kind), subclassData(subclassData) {}
+
+ unsigned getSubclassData() const { return subclassData; }
+
+ void setSubclassData(unsigned val) {
+ subclassData = val;
+ // Ensure we don't have any accidental truncation.
+ assert(getSubclassData() == val && "Subclass data too large for field");
+ }
+
+ /// This refers to the MLIRContext in which this type was uniqued.
+ MLIRContext *const context;
+
+ /// Classification of the subclass, used for type checking.
+ Type::Kind kind : 8;
+
+ /// Space for subclasses to store data.
+ unsigned subclassData : 24;
+};
+
+struct IntegerTypeStorage : public TypeStorage {
+ unsigned width;
+};
+
+struct FloatTypeStorage : public TypeStorage {};
+
+struct OtherTypeStorage : public TypeStorage {};
+
+struct FunctionTypeStorage : public TypeStorage {
+ ArrayRef<Type> getInputs() const {
+ return ArrayRef<Type>(inputsAndResults, subclassData);
+ }
+ ArrayRef<Type> getResults() const {
+ return ArrayRef<Type>(inputsAndResults + subclassData, numResults);
+ }
+
+ unsigned numResults;
+ Type const *inputsAndResults;
+};
+
+struct VectorOrTensorTypeStorage : public TypeStorage {
+ Type elementType;
+};
+
+struct VectorTypeStorage : public VectorOrTensorTypeStorage {
+ ArrayRef<int> getShape() const {
+ return ArrayRef<int>(shapeElements, getSubclassData());
+ }
+
+ const int *shapeElements;
+};
+
+struct TensorTypeStorage : public VectorOrTensorTypeStorage {};
+
+struct RankedTensorTypeStorage : public TensorTypeStorage {
+ ArrayRef<int> getShape() const {
+ return ArrayRef<int>(shapeElements, getSubclassData());
+ }
+
+ const int *shapeElements;
+};
+
+struct UnrankedTensorTypeStorage : public TensorTypeStorage {};
+
+struct MemRefTypeStorage : public TypeStorage {
+ ArrayRef<int> getShape() const {
+ return ArrayRef<int>(shapeElements, getSubclassData());
+ }
+
+ ArrayRef<AffineMap> getAffineMaps() const {
+ return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
+ }
+
+ /// The type of each scalar element of the memref.
+ Type elementType;
+ /// An array of integers which stores the shape dimension sizes.
+ const int *shapeElements;
+ /// The number of affine maps in the 'affineMapList' array.
+ const unsigned numAffineMaps;
+ /// List of affine maps in the memref's layout/index map composition.
+ AffineMap const *affineMapList;
+ /// Memory space in which data referenced by memref resides.
+ const unsigned memorySpace;
+};
+
+} // namespace detail
+} // namespace mlir
+#endif // TYPEDETAIL_H_
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index 0ad3f47..1a71695 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -16,10 +16,17 @@
// =============================================================================
#include "mlir/IR/Types.h"
+#include "TypeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
+
using namespace mlir;
+using namespace mlir::detail;
+
+Type::Kind Type::getKind() const { return type->kind; }
+
+MLIRContext *Type::getContext() const { return type->context; }
unsigned Type::getBitWidth() const {
switch (getKind()) {
@@ -32,34 +39,49 @@
case Type::Kind::F64:
return 64;
case Type::Kind::Integer:
- return cast<IntegerType>(this)->getWidth();
+ return cast<IntegerType>().getWidth();
case Type::Kind::Vector:
case Type::Kind::RankedTensor:
case Type::Kind::UnrankedTensor:
- return cast<VectorOrTensorType>(this)->getElementType()->getBitWidth();
+ return cast<VectorOrTensorType>().getElementType().getBitWidth();
// TODO: Handle more types.
default:
llvm_unreachable("unexpected type");
}
}
-IntegerType::IntegerType(unsigned width, MLIRContext *context)
- : Type(Kind::Integer, context), width(width) {
- assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
+unsigned Type::getSubclassData() const { return type->getSubclassData(); }
+void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
+
+IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {}
+
+unsigned IntegerType::getWidth() const {
+ return static_cast<ImplType *>(type)->width;
}
-FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {}
+FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {}
-OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {}
+OtherType::OtherType(Type::ImplType *ptr) : Type(ptr) {}
-FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
- unsigned numResults, MLIRContext *context)
- : Type(Kind::Function, context, numInputs), numResults(numResults),
- inputsAndResults(inputsAndResults) {}
+FunctionType::FunctionType(Type::ImplType *ptr) : Type(ptr) {}
-VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
- Type *elementType, unsigned subClassData)
- : Type(kind, context, subClassData), elementType(elementType) {}
+ArrayRef<Type> FunctionType::getInputs() const {
+ return static_cast<ImplType *>(type)->getInputs();
+}
+
+unsigned FunctionType::getNumResults() const {
+ return static_cast<ImplType *>(type)->numResults;
+}
+
+ArrayRef<Type> FunctionType::getResults() const {
+ return static_cast<ImplType *>(type)->getResults();
+}
+
+VectorOrTensorType::VectorOrTensorType(Type::ImplType *ptr) : Type(ptr) {}
+
+Type VectorOrTensorType::getElementType() const {
+ return static_cast<ImplType *>(type)->elementType;
+}
unsigned VectorOrTensorType::getNumElements() const {
switch (getKind()) {
@@ -103,11 +125,11 @@
ArrayRef<int> VectorOrTensorType::getShape() const {
switch (getKind()) {
case Kind::Vector:
- return cast<VectorType>(this)->getShape();
+ return cast<VectorType>().getShape();
case Kind::RankedTensor:
- return cast<RankedTensorType>(this)->getShape();
+ return cast<RankedTensorType>().getShape();
case Kind::UnrankedTensor:
- return cast<RankedTensorType>(this)->getShape();
+ llvm_unreachable("unranked tensor types do not have a shape");
default:
llvm_unreachable("not a VectorOrTensorType");
}
}
@@ -118,35 +140,38 @@
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
}
-VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
- MLIRContext *context)
- : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
- shapeElements(shape.data()) {}
+VectorType::VectorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
-TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
- : VectorOrTensorType(kind, context, elementType) {
- assert(isValidTensorElementType(elementType));
+ArrayRef<int> VectorType::getShape() const {
+ return static_cast<ImplType *>(type)->getShape();
}
-RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
- MLIRContext *context)
- : TensorType(Kind::RankedTensor, elementType, context),
- shapeElements(shape.data()) {
- setSubclassData(shape.size());
+TensorType::TensorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
+
+RankedTensorType::RankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
+
+ArrayRef<int> RankedTensorType::getShape() const {
+ return static_cast<ImplType *>(type)->getShape();
}
-UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
- : TensorType(Kind::UnrankedTensor, elementType, context) {}
+UnrankedTensorType::UnrankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
-MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType,
- ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
- MLIRContext *context)
- : Type(Kind::MemRef, context, shape.size()), elementType(elementType),
- shapeElements(shape.data()), numAffineMaps(affineMapList.size()),
- affineMapList(affineMapList.data()), memorySpace(memorySpace) {}
+MemRefType::MemRefType(Type::ImplType *ptr) : Type(ptr) {}
+
+ArrayRef<int> MemRefType::getShape() const {
+ return static_cast<ImplType *>(type)->getShape();
+}
+
+Type MemRefType::getElementType() const {
+ return static_cast<ImplType *>(type)->elementType;
+}
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
- return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
+ return static_cast<ImplType *>(type)->getAffineMaps();
+}
+
+unsigned MemRefType::getMemorySpace() const {
+ return static_cast<ImplType *>(type)->memorySpace;
}
unsigned MemRefType::getNumDynamicDims() const {
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 7974c7c..ceb8931 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -182,19 +182,19 @@
// as the results of their action.
// Type parsing.
- VectorType *parseVectorType();
+ VectorType parseVectorType();
ParseResult parseXInDimensionList();
ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
- Type *parseTensorType();
- Type *parseMemRefType();
- Type *parseFunctionType();
- Type *parseType();
- ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements);
- ParseResult parseTypeList(SmallVectorImpl<Type *> &elements);
+ Type parseTensorType();
+ Type parseMemRefType();
+ Type parseFunctionType();
+ Type parseType();
+ ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
+ ParseResult parseTypeList(SmallVectorImpl<Type> &elements);
// Attribute parsing.
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
- FunctionType *type);
+ FunctionType type);
Attribute parseAttribute();
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
@@ -206,9 +206,9 @@
AffineMap parseAffineMapReference();
IntegerSet parseIntegerSetInline();
IntegerSet parseIntegerSetReference();
- DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType *type);
- DenseElementsAttr parseDenseElementsAttr(Type *eltType, bool isVector);
- VectorOrTensorType *parseVectorOrTensorType();
+ DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType type);
+ DenseElementsAttr parseDenseElementsAttr(Type eltType, bool isVector);
+ VectorOrTensorType parseVectorOrTensorType();
private:
// The Parser is subclassed and reinstantiated. Do not add additional
@@ -299,7 +299,7 @@
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
/// other-type ::= `index` | `tf_control`
///
-Type *Parser::parseType() {
+Type Parser::parseType() {
switch (getToken().getKind()) {
default:
return (emitError("expected type"), nullptr);
@@ -368,7 +368,7 @@
/// vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
/// const-dimension-list ::= (integer-literal `x`)+
///
-VectorType *Parser::parseVectorType() {
+VectorType Parser::parseVectorType() {
consumeToken(Token::kw_vector);
if (parseToken(Token::less, "expected '<' in vector type"))
@@ -402,11 +402,11 @@
// Parse the element type.
auto typeLoc = getToken().getLoc();
- auto *elementType = parseType();
+ auto elementType = parseType();
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
- if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
+ if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
return (emitError(typeLoc, "invalid vector element type"), nullptr);
return VectorType::get(dimensions, elementType);
@@ -461,7 +461,7 @@
/// tensor-type ::= `tensor` `<` dimension-list element-type `>`
/// dimension-list ::= dimension-list-ranked | `*x`
///
-Type *Parser::parseTensorType() {
+Type Parser::parseTensorType() {
consumeToken(Token::kw_tensor);
if (parseToken(Token::less, "expected '<' in tensor type"))
@@ -485,7 +485,7 @@
// Parse the element type.
auto typeLoc = getToken().getLoc();
- auto *elementType = parseType();
+ auto elementType = parseType();
if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
return nullptr;
@@ -505,7 +505,7 @@
/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
/// memory-space ::= integer-literal /* | TODO: address-space-id */
///
-Type *Parser::parseMemRefType() {
+Type Parser::parseMemRefType() {
consumeToken(Token::kw_memref);
if (parseToken(Token::less, "expected '<' in memref type"))
@@ -517,12 +517,12 @@
// Parse the element type.
auto typeLoc = getToken().getLoc();
- auto *elementType = parseType();
+ auto elementType = parseType();
if (!elementType)
return nullptr;
- if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) &&
- !isa<VectorType>(elementType))
+ if (!elementType.isa<IntegerType>() && !elementType.isa<FloatType>() &&
+ !elementType.isa<VectorType>())
return (emitError(typeLoc, "invalid memref element type"), nullptr);
// Parse semi-affine-map-composition.
@@ -581,10 +581,10 @@
///
/// function-type ::= type-list-parens `->` type-list
///
-Type *Parser::parseFunctionType() {
+Type Parser::parseFunctionType() {
assert(getToken().is(Token::l_paren));
- SmallVector<Type *, 4> arguments, results;
+ SmallVector<Type, 4> arguments, results;
if (parseTypeList(arguments) ||
parseToken(Token::arrow, "expected '->' in function type") ||
parseTypeList(results))
@@ -598,7 +598,7 @@
///
/// type-list-no-parens ::= type (`,` type)*
///
-ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
+ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
auto parseElt = [&]() -> ParseResult {
auto elt = parseType();
elements.push_back(elt);
@@ -615,7 +615,7 @@
/// type-list-parens ::= `(` `)`
/// | `(` type-list-no-parens `)`
///
-ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) {
+ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) {
auto parseElt = [&]() -> ParseResult {
auto elt = parseType();
elements.push_back(elt);
@@ -639,8 +639,8 @@
namespace {
class TensorLiteralParser {
public:
- TensorLiteralParser(Parser &p, Type *eltTy)
- : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy->getBitWidth()) {}
+ TensorLiteralParser(Parser &p, Type eltTy)
+ : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy.getBitWidth()) {}
ParseResult parse() { return parseList(shape); }
@@ -676,7 +676,7 @@
}
Parser &p;
- Type *eltTy;
+ Type eltTy;
size_t currBitPos;
size_t bitsWidth;
SmallVector<int, 4> shape;
@@ -698,7 +698,7 @@
if (!result)
return p.emitError("expected tensor element");
// check result matches the element type.
- switch (eltTy->getKind()) {
+ switch (eltTy.getKind()) {
case Type::Kind::BF16:
case Type::Kind::F16:
case Type::Kind::F32:
@@ -779,7 +779,7 @@
/// synthesizing a forward reference) or emit an error and return null on
/// failure.
Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
- FunctionType *type) {
+ FunctionType type) {
Identifier name = builder.getIdentifier(nameStr.drop_front());
// See if the function has already been defined in the module.
@@ -902,10 +902,10 @@
if (parseToken(Token::colon, "expected ':' and function type"))
return nullptr;
auto typeLoc = getToken().getLoc();
- Type *type = parseType();
+ Type type = parseType();
if (!type)
return nullptr;
- auto *fnType = dyn_cast<FunctionType>(type);
+ auto fnType = type.dyn_cast<FunctionType>();
if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr);
@@ -916,7 +916,7 @@
consumeToken(Token::kw_opaque);
if (parseToken(Token::less, "expected '<' after 'opaque'"))
return nullptr;
- auto *type = parseVectorOrTensorType();
+ auto type = parseVectorOrTensorType();
if (!type)
return nullptr;
auto val = getToken().getStringValue();
@@ -937,7 +937,7 @@
if (parseToken(Token::less, "expected '<' after 'splat'"))
return nullptr;
- auto *type = parseVectorOrTensorType();
+ auto type = parseVectorOrTensorType();
if (!type)
return nullptr;
switch (getToken().getKind()) {
@@ -959,7 +959,7 @@
if (parseToken(Token::less, "expected '<' after 'dense'"))
return nullptr;
- auto *type = parseVectorOrTensorType();
+ auto type = parseVectorOrTensorType();
if (!type)
return nullptr;
@@ -981,41 +981,41 @@
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
return nullptr;
- auto *type = parseVectorOrTensorType();
+ auto type = parseVectorOrTensorType();
if (!type)
return nullptr;
switch (getToken().getKind()) {
case Token::l_square: {
/// Parse indices
- auto *indicesEltType = builder.getIntegerType(32);
+ auto indicesEltType = builder.getIntegerType(32);
auto indices =
- parseDenseElementsAttr(indicesEltType, isa<VectorType>(type));
+ parseDenseElementsAttr(indicesEltType, type.isa<VectorType>());
if (parseToken(Token::comma, "expected ','"))
return nullptr;
/// Parse values.
- auto *valuesEltType = type->getElementType();
+ auto valuesEltType = type.getElementType();
auto values =
- parseDenseElementsAttr(valuesEltType, isa<VectorType>(type));
+ parseDenseElementsAttr(valuesEltType, type.isa<VectorType>());
/// Sanity check.
- auto *indicesType = indices.getType();
- auto *valuesType = values.getType();
- auto sameShape = (indicesType->getRank() == 1) ||
- (type->getRank() == indicesType->getDimSize(1));
+ auto indicesType = indices.getType();
+ auto valuesType = values.getType();
+ auto sameShape = (indicesType.getRank() == 1) ||
+ (type.getRank() == indicesType.getDimSize(1));
auto sameElementNum =
- indicesType->getDimSize(0) == valuesType->getDimSize(0);
+ indicesType.getDimSize(0) == valuesType.getDimSize(0);
if (!sameShape || !sameElementNum) {
std::string str;
llvm::raw_string_ostream s(str);
s << "expected shape ([";
- interleaveComma(type->getShape(), s);
+ interleaveComma(type.getShape(), s);
s << "]); inferred shape of indices literal ([";
- interleaveComma(indicesType->getShape(), s);
+ interleaveComma(indicesType.getShape(), s);
s << "]); inferred shape of values literal ([";
- interleaveComma(valuesType->getShape(), s);
+ interleaveComma(valuesType.getShape(), s);
s << "])";
return (emitError(s.str()), nullptr);
}
@@ -1035,7 +1035,7 @@
nullptr);
}
default: {
- if (Type *type = parseType())
+ if (Type type = parseType())
return builder.getTypeAttr(type);
return nullptr;
}
@@ -1051,12 +1051,12 @@
///
/// This method returns a constructed dense elements attribute with the shape
/// from the parsing result.
-DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
+DenseElementsAttr Parser::parseDenseElementsAttr(Type eltType, bool isVector) {
TensorLiteralParser literalParser(*this, eltType);
if (literalParser.parse())
return nullptr;
- VectorOrTensorType *type;
+ VectorOrTensorType type;
if (isVector) {
type = builder.getVectorType(literalParser.getShape(), eltType);
} else {
@@ -1076,18 +1076,18 @@
/// This method compares the shapes from the parsing result and that from the
/// input argument. It returns a constructed dense elements attribute if both
/// match.
-DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
- auto *eltTy = type->getElementType();
+DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) {
+ auto eltTy = type.getElementType();
TensorLiteralParser literalParser(*this, eltTy);
if (literalParser.parse())
return nullptr;
- if (literalParser.getShape() != type->getShape()) {
+ if (literalParser.getShape() != type.getShape()) {
std::string str;
llvm::raw_string_ostream s(str);
s << "inferred shape of elements literal ([";
interleaveComma(literalParser.getShape(), s);
s << "]) does not match type ([";
- interleaveComma(type->getShape(), s);
+ interleaveComma(type.getShape(), s);
s << "])";
return (emitError(s.str()), nullptr);
}
@@ -1100,8 +1100,8 @@
/// vector-or-tensor-type ::= vector-type | tensor-type
///
/// This method also checks the type has static shape and ranked.
-VectorOrTensorType *Parser::parseVectorOrTensorType() {
- auto *type = dyn_cast<VectorOrTensorType>(parseType());
+VectorOrTensorType Parser::parseVectorOrTensorType() {
+ auto type = parseType().dyn_cast<VectorOrTensorType>();
if (!type) {
return (emitError("expected elements literal has a tensor or vector type"),
nullptr);
@@ -1110,7 +1110,7 @@
if (parseToken(Token::comma, "expected ','"))
return nullptr;
- if (!type->hasStaticShape() || type->getRank() == -1) {
+ if (!type.hasStaticShape() || type.getRank() == -1) {
return (emitError("tensor literals must be ranked and have static shape"),
nullptr);
}
@@ -1834,7 +1834,7 @@
/// Given a reference to an SSA value and its type, return a reference. This
/// returns null on failure.
- SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type *type);
+ SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type);
/// Register a definition of a value with the symbol table.
ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value);
@@ -1845,11 +1845,11 @@
template <typename ResultType>
ResultType parseSSADefOrUseAndType(
- const std::function<ResultType(SSAUseInfo, Type *)> &action);
+ const std::function<ResultType(SSAUseInfo, Type)> &action);
SSAValue *parseSSAUseAndType() {
return parseSSADefOrUseAndType<SSAValue *>(
- [&](SSAUseInfo useInfo, Type *type) -> SSAValue * {
+ [&](SSAUseInfo useInfo, Type type) -> SSAValue * {
return resolveSSAUse(useInfo, type);
});
}
@@ -1880,7 +1880,7 @@
/// their first reference, to allow checking for use of undefined values.
DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders;
- SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type *type);
+ SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type);
/// Return true if this is a forward reference.
bool isForwardReferencePlaceholder(SSAValue *value) {
@@ -1891,7 +1891,7 @@
/// Create and remember a new placeholder for a forward reference.
SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
- Type *type) {
+ Type type) {
// Forward references are always created as instructions, even in ML
// functions, because we just need something with a def/use chain.
//
@@ -1908,7 +1908,7 @@
/// Given an unbound reference to an SSA value and its type, return the value
/// it specifies. This returns null on failure.
-SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) {
+SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
auto &entries = values[useInfo.name];
// If we have already seen a value of this name, return it.
@@ -2057,14 +2057,14 @@
/// ssa-use-and-type ::= ssa-use `:` type
template <typename ResultType>
ResultType FunctionParser::parseSSADefOrUseAndType(
- const std::function<ResultType(SSAUseInfo, Type *)> &action) {
+ const std::function<ResultType(SSAUseInfo, Type)> &action) {
SSAUseInfo useInfo;
if (parseSSAUse(useInfo) ||
parseToken(Token::colon, "expected ':' and type for SSA operand"))
return nullptr;
- auto *type = parseType();
+ auto type = parseType();
if (!type)
return nullptr;
@@ -2101,7 +2101,7 @@
if (valueIDs.empty())
return ParseSuccess;
- SmallVector<Type *, 4> types;
+ SmallVector<Type, 4> types;
if (parseToken(Token::colon, "expected ':' in operand list") ||
parseTypeListNoParens(types))
return ParseFailure;
@@ -2209,14 +2209,14 @@
auto type = parseType();
if (!type)
return nullptr;
- auto fnType = dyn_cast<FunctionType>(type);
+ auto fnType = type.dyn_cast<FunctionType>();
if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr);
- result.addTypes(fnType->getResults());
+ result.addTypes(fnType.getResults());
// Check that we have the right number of types for the operands.
- auto operandTypes = fnType->getInputs();
+ auto operandTypes = fnType.getInputs();
if (operandTypes.size() != operandInfos.size()) {
auto plural = "s"[operandInfos.size() == 1];
return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
@@ -2253,17 +2253,17 @@
return parser.parseToken(Token::comma, "expected ','");
}
- bool parseColonType(Type *&result) override {
+ bool parseColonType(Type &result) override {
return parser.parseToken(Token::colon, "expected ':'") ||
!(result = parser.parseType());
}
- bool parseColonTypeList(SmallVectorImpl<Type *> &result) override {
+ bool parseColonTypeList(SmallVectorImpl<Type> &result) override {
if (parser.parseToken(Token::colon, "expected ':'"))
return true;
do {
- if (auto *type = parser.parseType())
+ if (auto type = parser.parseType())
result.push_back(type);
else
return true;
@@ -2273,7 +2273,7 @@
}
/// Parse a keyword followed by a type.
- bool parseKeywordType(const char *keyword, Type *&result) override {
+ bool parseKeywordType(const char *keyword, Type &result) override {
if (parser.getTokenSpelling() != keyword)
return parser.emitError("expected '" + Twine(keyword) + "'");
parser.consumeToken();
@@ -2396,7 +2396,7 @@
}
/// Resolve a parse function name and a type into a function reference.
- virtual bool resolveFunctionName(StringRef name, FunctionType *type,
+ virtual bool resolveFunctionName(StringRef name, FunctionType type,
llvm::SMLoc loc, Function *&result) {
result = parser.resolveFunctionReference(name, loc, type);
return result == nullptr;
@@ -2410,7 +2410,7 @@
llvm::SMLoc getNameLoc() const override { return nameLoc; }
- bool resolveOperand(const OperandType &operand, Type *type,
+ bool resolveOperand(const OperandType &operand, Type type,
SmallVectorImpl<SSAValue *> &result) override {
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
operand.location};
@@ -2559,11 +2559,11 @@
return ParseSuccess;
return parseCommaSeparatedList([&]() -> ParseResult {
- auto type = parseSSADefOrUseAndType<Type *>(
- [&](SSAUseInfo useInfo, Type *type) -> Type * {
+ auto type = parseSSADefOrUseAndType<Type>(
+ [&](SSAUseInfo useInfo, Type type) -> Type {
BBArgument *arg = owner->addArgument(type);
if (addDefinition(useInfo, arg))
- return nullptr;
+ return {};
return type;
});
return type ? ParseSuccess : ParseFailure;
@@ -2908,7 +2908,7 @@
" symbol count must match");
// Resolve SSA uses.
- Type *indexType = builder.getIndexType();
+ Type indexType = builder.getIndexType();
for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
SSAValue *sval = resolveSSAUse(opInfo[i], indexType);
if (!sval)
@@ -3187,9 +3187,9 @@
ParseResult parseAffineStructureDef();
// Functions.
- ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
+ ParseResult parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
SmallVectorImpl<StringRef> &argNames);
- ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type,
+ ParseResult parseFunctionSignature(StringRef &name, FunctionType &type,
SmallVectorImpl<StringRef> *argNames);
ParseResult parseFunctionAttribute(SmallVectorImpl<NamedAttribute> &attrs);
ParseResult parseExtFunc();
@@ -3248,7 +3248,7 @@
/// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
///
ParseResult
-ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
+ModuleParser::parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
SmallVectorImpl<StringRef> &argNames) {
consumeToken(Token::l_paren);
@@ -3284,7 +3284,7 @@
/// type-list)?
///
ParseResult
-ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
+ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
SmallVectorImpl<StringRef> *argNames) {
if (getToken().isNot(Token::at_identifier))
return emitError("expected a function identifier like '@foo'");
@@ -3295,7 +3295,7 @@
if (getToken().isNot(Token::l_paren))
return emitError("expected '(' in function signature");
- SmallVector<Type *, 4> argTypes;
+ SmallVector<Type, 4> argTypes;
ParseResult parseResult;
if (argNames)
@@ -3307,7 +3307,7 @@
return ParseFailure;
// Parse the return type if present.
- SmallVector<Type *, 4> results;
+ SmallVector<Type, 4> results;
if (consumeIf(Token::arrow)) {
if (parseTypeList(results))
return ParseFailure;
@@ -3340,7 +3340,7 @@
auto loc = getToken().getLoc();
StringRef name;
- FunctionType *type = nullptr;
+ FunctionType type;
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
return ParseFailure;
@@ -3372,7 +3372,7 @@
auto loc = getToken().getLoc();
StringRef name;
- FunctionType *type = nullptr;
+ FunctionType type;
if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
return ParseFailure;
@@ -3405,7 +3405,7 @@
consumeToken(Token::kw_mlfunc);
StringRef name;
- FunctionType *type = nullptr;
+ FunctionType type;
SmallVector<StringRef, 4> argNames;
auto loc = getToken().getLoc();
diff --git a/lib/StandardOps/StandardOps.cpp b/lib/StandardOps/StandardOps.cpp
index b60d209..e2bdfd7 100644
--- a/lib/StandardOps/StandardOps.cpp
+++ b/lib/StandardOps/StandardOps.cpp
@@ -138,23 +138,23 @@
//===----------------------------------------------------------------------===//
void AllocOp::build(Builder *builder, OperationState *result,
- MemRefType *memrefType, ArrayRef<SSAValue *> operands) {
+ MemRefType memrefType, ArrayRef<SSAValue *> operands) {
result->addOperands(operands);
result->types.push_back(memrefType);
}
void AllocOp::print(OpAsmPrinter *p) const {
- MemRefType *type = getType();
+ MemRefType type = getType();
*p << "alloc";
// Print dynamic dimension operands.
printDimAndSymbolList(operand_begin(), operand_end(),
- type->getNumDynamicDims(), p);
+ type.getNumDynamicDims(), p);
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
- *p << " : " << *type;
+ *p << " : " << type;
}
bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
- MemRefType *type;
+ MemRefType type;
// Parse the dimension operands and optional symbol operands, followed by a
// memref type.
@@ -170,7 +170,7 @@
// Verification still checks that the total number of operands matches
// the number of symbols in the affine map, plus the number of dynamic
// dimensions in the memref.
- if (numDimOperands != type->getNumDynamicDims()) {
+ if (numDimOperands != type.getNumDynamicDims()) {
return parser->emitError(parser->getNameLoc(),
"dimension operand count does not equal memref "
"dynamic dimension count");
@@ -180,13 +180,13 @@
}
bool AllocOp::verify() const {
- auto *memRefType = dyn_cast<MemRefType>(getResult()->getType());
+ auto memRefType = getResult()->getType().dyn_cast<MemRefType>();
if (!memRefType)
return emitOpError("result must be a memref");
unsigned numSymbols = 0;
- if (!memRefType->getAffineMaps().empty()) {
- AffineMap affineMap = memRefType->getAffineMaps()[0];
+ if (!memRefType.getAffineMaps().empty()) {
+ AffineMap affineMap = memRefType.getAffineMaps()[0];
// Store number of symbols used in affine map (used in subsequent check).
numSymbols = affineMap.getNumSymbols();
// TODO(zinenko): this check does not belong to AllocOp, or any other op but
@@ -195,10 +195,10 @@
// Remove when we can emit errors directly from *Type::get(...) functions.
//
// Verify that the layout affine map matches the rank of the memref.
- if (affineMap.getNumDims() != memRefType->getRank())
+ if (affineMap.getNumDims() != memRefType.getRank())
return emitOpError("affine map dimension count must equal memref rank");
}
- unsigned numDynamicDims = memRefType->getNumDynamicDims();
+ unsigned numDynamicDims = memRefType.getNumDynamicDims();
// Check that the total number of operands matches the number of symbols in
// the affine map, plus the number of dynamic dimensions specified in the
// memref type.
@@ -208,7 +208,7 @@
}
// Verify that all operands are of type Index.
for (auto *operand : getOperands()) {
- if (!operand->getType()->isIndex())
+ if (!operand->getType().isIndex())
return emitOpError("requires operands to be of type Index");
}
return false;
@@ -239,13 +239,13 @@
// Ok, we have one or more constant operands. Collect the non-constant ones
// and keep track of the resultant memref type to build.
SmallVector<int, 4> newShapeConstants;
- newShapeConstants.reserve(memrefType->getRank());
+ newShapeConstants.reserve(memrefType.getRank());
SmallVector<SSAValue *, 4> newOperands;
SmallVector<SSAValue *, 4> droppedOperands;
unsigned dynamicDimPos = 0;
- for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) {
- int dimSize = memrefType->getDimSize(dim);
+ for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
+ int dimSize = memrefType.getDimSize(dim);
// If this is already static dimension, keep it.
if (dimSize != -1) {
newShapeConstants.push_back(dimSize);
@@ -267,10 +267,10 @@
}
// Create new memref type (which will have fewer dynamic dimensions).
- auto *newMemRefType = MemRefType::get(
- newShapeConstants, memrefType->getElementType(),
- memrefType->getAffineMaps(), memrefType->getMemorySpace());
- assert(newOperands.size() == newMemRefType->getNumDynamicDims());
+ auto newMemRefType = MemRefType::get(
+ newShapeConstants, memrefType.getElementType(),
+ memrefType.getAffineMaps(), memrefType.getMemorySpace());
+ assert(newOperands.size() == newMemRefType.getNumDynamicDims());
// Create and insert the alloc op for the new memref.
auto newAlloc =
@@ -297,13 +297,13 @@
ArrayRef<SSAValue *> operands) {
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
- result->addTypes(callee->getType()->getResults());
+ result->addTypes(callee->getType().getResults());
}
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
StringRef calleeName;
llvm::SMLoc calleeLoc;
- FunctionType *calleeType = nullptr;
+ FunctionType calleeType;
SmallVector<OpAsmParser::OperandType, 4> operands;
Function *callee = nullptr;
if (parser->parseFunctionName(calleeName, calleeLoc) ||
@@ -312,8 +312,8 @@
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) ||
parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
- parser->addTypesToList(calleeType->getResults(), result->types) ||
- parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
+ parser->addTypesToList(calleeType.getResults(), result->types) ||
+ parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
result->operands))
return true;
@@ -328,7 +328,7 @@
p->printOperands(getOperands());
*p << ')';
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
- *p << " : " << *getCallee()->getType();
+ *p << " : " << getCallee()->getType();
}
bool CallOp::verify() const {
@@ -338,20 +338,20 @@
return emitOpError("requires a 'callee' function attribute");
// Verify that the operand and result types match the callee.
- auto *fnType = fnAttr.getValue()->getType();
- if (fnType->getNumInputs() != getNumOperands())
+ auto fnType = fnAttr.getValue()->getType();
+ if (fnType.getNumInputs() != getNumOperands())
return emitOpError("incorrect number of operands for callee");
- for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
- if (getOperand(i)->getType() != fnType->getInput(i))
+ for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+ if (getOperand(i)->getType() != fnType.getInput(i))
return emitOpError("operand type mismatch");
}
- if (fnType->getNumResults() != getNumResults())
+ if (fnType.getNumResults() != getNumResults())
return emitOpError("incorrect number of results for callee");
- for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
- if (getResult(i)->getType() != fnType->getResult(i))
+ for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
+ if (getResult(i)->getType() != fnType.getResult(i))
return emitOpError("result type mismatch");
}
@@ -364,14 +364,14 @@
void CallIndirectOp::build(Builder *builder, OperationState *result,
SSAValue *callee, ArrayRef<SSAValue *> operands) {
- auto *fnType = cast<FunctionType>(callee->getType());
+ auto fnType = callee->getType().cast<FunctionType>();
result->operands.push_back(callee);
result->addOperands(operands);
- result->addTypes(fnType->getResults());
+ result->addTypes(fnType.getResults());
}
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
- FunctionType *calleeType = nullptr;
+ FunctionType calleeType;
OpAsmParser::OperandType callee;
llvm::SMLoc operandsLoc;
SmallVector<OpAsmParser::OperandType, 4> operands;
@@ -382,9 +382,9 @@
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) ||
parser->resolveOperand(callee, calleeType, result->operands) ||
- parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
+ parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
result->operands) ||
- parser->addTypesToList(calleeType->getResults(), result->types);
+ parser->addTypesToList(calleeType.getResults(), result->types);
}
void CallIndirectOp::print(OpAsmPrinter *p) const {
@@ -395,29 +395,29 @@
p->printOperands(++operandRange.begin(), operandRange.end());
*p << ')';
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
- *p << " : " << *getCallee()->getType();
+ *p << " : " << getCallee()->getType();
}
bool CallIndirectOp::verify() const {
// The callee must be a function.
- auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
+ auto fnType = getCallee()->getType().dyn_cast<FunctionType>();
if (!fnType)
return emitOpError("callee must have function type");
// Verify that the operand and result types match the callee.
- if (fnType->getNumInputs() != getNumOperands() - 1)
+ if (fnType.getNumInputs() != getNumOperands() - 1)
return emitOpError("incorrect number of operands for callee");
- for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
- if (getOperand(i + 1)->getType() != fnType->getInput(i))
+ for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+ if (getOperand(i + 1)->getType() != fnType.getInput(i))
return emitOpError("operand type mismatch");
}
- if (fnType->getNumResults() != getNumResults())
+ if (fnType.getNumResults() != getNumResults())
return emitOpError("incorrect number of results for callee");
- for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
- if (getResult(i)->getType() != fnType->getResult(i))
+ for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
+ if (getResult(i)->getType() != fnType.getResult(i))
return emitOpError("result type mismatch");
}
@@ -434,19 +434,19 @@
}
void DeallocOp::print(OpAsmPrinter *p) const {
- *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType();
+ *p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType();
}
bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo;
- MemRefType *type;
+ MemRefType type;
return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands);
}
bool DeallocOp::verify() const {
- if (!isa<MemRefType>(getMemRef()->getType()))
+ if (!getMemRef()->getType().isa<MemRefType>())
return emitOpError("operand must be a memref");
return false;
}
@@ -472,13 +472,13 @@
void DimOp::print(OpAsmPrinter *p) const {
*p << "dim " << *getOperand() << ", " << getIndex();
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
- *p << " : " << *getOperand()->getType();
+ *p << " : " << getOperand()->getType();
}
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType operandInfo;
IntegerAttr indexAttr;
- Type *type;
+ Type type;
return parser->parseOperand(operandInfo) || parser->parseComma() ||
parser->parseAttribute(indexAttr, "index", result->attributes) ||
@@ -496,15 +496,15 @@
return emitOpError("requires an integer attribute named 'index'");
uint64_t index = (uint64_t)indexAttr.getValue();
- auto *type = getOperand()->getType();
- if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
- if (index >= tensorType->getRank())
+ auto type = getOperand()->getType();
+ if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+ if (index >= tensorType.getRank())
return emitOpError("index is out of range");
- } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
- if (index >= memrefType->getRank())
+ } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
+ if (index >= memrefType.getRank())
return emitOpError("index is out of range");
- } else if (isa<UnrankedTensorType>(type)) {
+ } else if (type.isa<UnrankedTensorType>()) {
// ok, assumed to be in-range.
} else {
return emitOpError("requires an operand with tensor or memref type");
@@ -516,12 +516,12 @@
Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
// Constant fold dim when the size along the index referred to is a constant.
- auto *opType = getOperand()->getType();
+ auto opType = getOperand()->getType();
int indexSize = -1;
- if (auto *tensorType = dyn_cast<RankedTensorType>(opType)) {
- indexSize = tensorType->getShape()[getIndex()];
- } else if (auto *memrefType = dyn_cast<MemRefType>(opType)) {
- indexSize = memrefType->getShape()[getIndex()];
+ if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
+ indexSize = tensorType.getShape()[getIndex()];
+ } else if (auto memrefType = opType.dyn_cast<MemRefType>()) {
+ indexSize = memrefType.getShape()[getIndex()];
}
if (indexSize >= 0)
@@ -544,9 +544,9 @@
p->printOperands(getTagIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getSrcMemRef()->getType();
- *p << ", " << *getDstMemRef()->getType();
- *p << ", " << *getTagMemRef()->getType();
+ *p << " : " << getSrcMemRef()->getType();
+ *p << ", " << getDstMemRef()->getType();
+ *p << ", " << getTagMemRef()->getType();
}
// Parse DmaStartOp.
@@ -566,8 +566,8 @@
OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
- SmallVector<Type *, 3> types;
- auto *indexType = parser->getBuilder().getIndexType();
+ SmallVector<Type, 3> types;
+ auto indexType = parser->getBuilder().getIndexType();
// Parse and resolve the following list of operands:
// *) source memref followed by its indices (in square brackets).
@@ -601,12 +601,12 @@
return true;
// Check that source/destination index list size matches associated rank.
- if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() ||
- dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank())
+ if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() ||
+ dstIndexInfos.size() != types[1].cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(),
"memref rank not equal to indices count");
- if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank())
+ if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count");
@@ -632,7 +632,7 @@
p->printOperands(getTagIndices());
*p << "], ";
p->printOperand(getNumElements());
- *p << " : " << *getTagMemRef()->getType();
+ *p << " : " << getTagMemRef()->getType();
}
// Parse DmaWaitOp.
@@ -642,8 +642,8 @@
bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
- Type *type;
- auto *indexType = parser->getBuilder().getIndexType();
+ Type type;
+ auto indexType = parser->getBuilder().getIndexType();
OpAsmParser::OperandType numElementsInfo;
// Parse tag memref, its indices, and dma size.
@@ -657,7 +657,7 @@
parser->resolveOperand(numElementsInfo, indexType, result->operands))
return true;
- if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank())
+ if (tagIndexInfos.size() != type.cast<MemRefType>().getRank())
return parser->emitError(parser->getNameLoc(),
"tag memref rank not equal to indices count");
@@ -678,10 +678,10 @@
void ExtractElementOp::build(Builder *builder, OperationState *result,
SSAValue *aggregate,
ArrayRef<SSAValue *> indices) {
- auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
+ auto aggregateType = aggregate->getType().cast<VectorOrTensorType>();
result->addOperands(aggregate);
result->addOperands(indices);
- result->types.push_back(aggregateType->getElementType());
+ result->types.push_back(aggregateType.getElementType());
}
void ExtractElementOp::print(OpAsmPrinter *p) const {
@@ -689,13 +689,13 @@
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getAggregate()->getType();
+ *p << " : " << getAggregate()->getType();
}
bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType aggregateInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- VectorOrTensorType *type;
+ VectorOrTensorType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(aggregateInfo) ||
@@ -705,26 +705,26 @@
parser->parseColonType(type) ||
parser->resolveOperand(aggregateInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
- parser->addTypeToList(type->getElementType(), result->types);
+ parser->addTypeToList(type.getElementType(), result->types);
}
bool ExtractElementOp::verify() const {
if (getNumOperands() == 0)
return emitOpError("expected an aggregate to index into");
- auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
+ auto aggregateType = getAggregate()->getType().dyn_cast<VectorOrTensorType>();
if (!aggregateType)
return emitOpError("first operand must be a vector or tensor");
- if (getType() != aggregateType->getElementType())
+ if (getType() != aggregateType.getElementType())
return emitOpError("result type must match element type of aggregate");
for (auto *idx : getIndices())
- if (!idx->getType()->isIndex())
+ if (!idx->getType().isIndex())
return emitOpError("index to extract_element must have 'index' type");
// Verify the # indices match if we have a ranked type.
- auto aggregateRank = aggregateType->getRank();
+ auto aggregateRank = aggregateType.getRank();
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
return emitOpError("incorrect number of indices for extract_element");
@@ -737,10 +737,10 @@
void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
ArrayRef<SSAValue *> indices) {
- auto *memrefType = cast<MemRefType>(memref->getType());
+ auto memrefType = memref->getType().cast<MemRefType>();
result->addOperands(memref);
result->addOperands(indices);
- result->types.push_back(memrefType->getElementType());
+ result->types.push_back(memrefType.getElementType());
}
void LoadOp::print(OpAsmPrinter *p) const {
@@ -748,13 +748,13 @@
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getMemRefType();
+ *p << " : " << getMemRefType();
}
bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- MemRefType *type;
+ MemRefType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(memrefInfo) ||
@@ -764,25 +764,25 @@
parser->parseColonType(type) ||
parser->resolveOperand(memrefInfo, type, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
- parser->addTypeToList(type->getElementType(), result->types);
+ parser->addTypeToList(type.getElementType(), result->types);
}
bool LoadOp::verify() const {
if (getNumOperands() == 0)
return emitOpError("expected a memref to load from");
- auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
+ auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
if (!memRefType)
return emitOpError("first operand must be a memref");
- if (getType() != memRefType->getElementType())
+ if (getType() != memRefType.getElementType())
return emitOpError("result type must match element type of memref");
- if (memRefType->getRank() != getNumOperands() - 1)
+ if (memRefType.getRank() != getNumOperands() - 1)
return emitOpError("incorrect number of indices for load");
for (auto *idx : getIndices())
- if (!idx->getType()->isIndex())
+ if (!idx->getType().isIndex())
return emitOpError("index to load must have 'index' type");
// TODO: Verify we have the right number of indices.
@@ -804,31 +804,31 @@
//===----------------------------------------------------------------------===//
bool MemRefCastOp::verify() const {
- auto *opType = dyn_cast<MemRefType>(getOperand()->getType());
- auto *resType = dyn_cast<MemRefType>(getType());
+ auto opType = getOperand()->getType().dyn_cast<MemRefType>();
+ auto resType = getType().dyn_cast<MemRefType>();
if (!opType || !resType)
return emitOpError("requires input and result types to be memrefs");
if (opType == resType)
return emitOpError("requires the input and result type to be different");
- if (opType->getElementType() != resType->getElementType())
+ if (opType.getElementType() != resType.getElementType())
return emitOpError(
"requires input and result element types to be the same");
- if (opType->getAffineMaps() != resType->getAffineMaps())
+ if (opType.getAffineMaps() != resType.getAffineMaps())
return emitOpError("requires input and result mappings to be the same");
- if (opType->getMemorySpace() != resType->getMemorySpace())
+ if (opType.getMemorySpace() != resType.getMemorySpace())
return emitOpError(
"requires input and result memory spaces to be the same");
// They must have the same rank, and any specified dimensions must match.
- if (opType->getRank() != resType->getRank())
+ if (opType.getRank() != resType.getRank())
return emitOpError("requires input and result ranks to match");
- for (unsigned i = 0, e = opType->getRank(); i != e; ++i) {
- int opDim = opType->getDimSize(i), resultDim = resType->getDimSize(i);
+ for (unsigned i = 0, e = opType.getRank(); i != e; ++i) {
+ int opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i);
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
return emitOpError("requires static dimensions to match");
}
@@ -923,14 +923,14 @@
p->printOperands(getIndices());
*p << ']';
p->printOptionalAttrDict(getAttrs());
- *p << " : " << *getMemRefType();
+ *p << " : " << getMemRefType();
}
bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo;
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
- MemRefType *memrefType;
+ MemRefType memrefType;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
@@ -939,7 +939,7 @@
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(memrefType) ||
- parser->resolveOperand(storeValueInfo, memrefType->getElementType(),
+ parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
result->operands) ||
parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
parser->resolveOperands(indexInfo, affineIntTy, result->operands);
@@ -950,19 +950,19 @@
return emitOpError("expected a value to store and a memref");
// Second operand is a memref type.
- auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
+ auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
if (!memRefType)
return emitOpError("second operand must be a memref");
// First operand must have same type as memref element type.
- if (getValueToStore()->getType() != memRefType->getElementType())
+ if (getValueToStore()->getType() != memRefType.getElementType())
return emitOpError("first operand must have same type memref element type");
- if (getNumOperands() != 2 + memRefType->getRank())
+ if (getNumOperands() != 2 + memRefType.getRank())
return emitOpError("store index operand count not equal to memref rank");
for (auto *idx : getIndices())
- if (!idx->getType()->isIndex())
+ if (!idx->getType().isIndex())
return emitOpError("index to load must have 'index' type");
// TODO: Verify we have the right number of indices.
@@ -1046,31 +1046,31 @@
//===----------------------------------------------------------------------===//
bool TensorCastOp::verify() const {
- auto *opType = dyn_cast<TensorType>(getOperand()->getType());
- auto *resType = dyn_cast<TensorType>(getType());
+ auto opType = getOperand()->getType().dyn_cast<TensorType>();
+ auto resType = getType().dyn_cast<TensorType>();
if (!opType || !resType)
return emitOpError("requires input and result types to be tensors");
if (opType == resType)
return emitOpError("requires the input and result type to be different");
- if (opType->getElementType() != resType->getElementType())
+ if (opType.getElementType() != resType.getElementType())
return emitOpError(
"requires input and result element types to be the same");
// If the source or destination are unranked, then the cast is valid.
- auto *opRType = dyn_cast<RankedTensorType>(opType);
- auto *resRType = dyn_cast<RankedTensorType>(resType);
+ auto opRType = opType.dyn_cast<RankedTensorType>();
+ auto resRType = resType.dyn_cast<RankedTensorType>();
if (!opRType || !resRType)
return false;
// If they are both ranked, they have to have the same rank, and any specified
// dimensions must match.
- if (opRType->getRank() != resRType->getRank())
+ if (opRType.getRank() != resRType.getRank())
return emitOpError("requires input and result ranks to match");
- for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
- int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
+ for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) {
+ int opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i);
if (opDim != -1 && resultDim != -1 && opDim != resultDim)
return emitOpError("requires static dimensions to match");
}
diff --git a/lib/Transforms/ConstantFold.cpp b/lib/Transforms/ConstantFold.cpp
index 81994dd..15dd89b 100644
--- a/lib/Transforms/ConstantFold.cpp
+++ b/lib/Transforms/ConstantFold.cpp
@@ -31,7 +31,7 @@
SmallVector<SSAValue *, 8> existingConstants;
// Operation statements that were folded and that need to be erased.
std::vector<OperationStmt *> opStmtsToErase;
- using ConstantFactoryType = std::function<SSAValue *(Attribute, Type *)>;
+ using ConstantFactoryType = std::function<SSAValue *(Attribute, Type)>;
bool foldOperation(Operation *op,
SmallVectorImpl<SSAValue *> &existingConstants,
@@ -106,7 +106,7 @@
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
auto &inst = *instIt++;
- auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
+ auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
builder.setInsertionPoint(&inst);
return builder.create<ConstantOp>(inst.getLoc(), value, type);
};
@@ -134,7 +134,7 @@
// Override the walker's operation statement visit for constant folding.
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
- auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
+ auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
MLFuncBuilder builder(stmt);
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
};
diff --git a/lib/Transforms/PipelineDataTransfer.cpp b/lib/Transforms/PipelineDataTransfer.cpp
index d96d65b..9042181 100644
--- a/lib/Transforms/PipelineDataTransfer.cpp
+++ b/lib/Transforms/PipelineDataTransfer.cpp
@@ -77,23 +77,23 @@
bInner.setInsertionPoint(forStmt, forStmt->begin());
// Doubles the shape with a leading dimension extent of 2.
- auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * {
+ auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
// Add the leading dimension in the shape for the double buffer.
- ArrayRef<int> shape = oldMemRefType->getShape();
+ ArrayRef<int> shape = oldMemRefType.getShape();
SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
shapeSizes.insert(shapeSizes.begin(), 2);
- auto *newMemRefType =
- bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {},
- oldMemRefType->getMemorySpace());
+ auto newMemRefType =
+ bInner.getMemRefType(shapeSizes, oldMemRefType.getElementType(), {},
+ oldMemRefType.getMemorySpace());
return newMemRefType;
};
- auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType()));
+ auto newMemRefType = doubleShape(oldMemRef->getType().cast<MemRefType>());
// Create and place the alloc at the top level.
MLFuncBuilder topBuilder(forStmt->getFunction());
- auto *newMemRef = cast<MLValue>(
+ auto newMemRef = cast<MLValue>(
topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
->getResult());
diff --git a/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cdf5b71..4ec8942 100644
--- a/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -78,7 +78,7 @@
/// As part of canonicalization, we move constants to the top of the entry
/// block of the current function and de-duplicate them. This keeps track of
/// constants we have done this for.
- DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants;
+ DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
};
}; // end anonymous namespace
diff --git a/lib/Transforms/Utils/Utils.cpp b/lib/Transforms/Utils/Utils.cpp
index edd8ce8..ad9d6dc 100644
--- a/lib/Transforms/Utils/Utils.cpp
+++ b/lib/Transforms/Utils/Utils.cpp
@@ -52,9 +52,9 @@
MLValue *newMemRef,
ArrayRef<MLValue *> extraIndices,
AffineMap indexRemap) {
- unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
+ unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank; // unused in opt mode
- unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
+ unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
(void)newMemRefRank;
if (indexRemap) {
assert(indexRemap.getNumInputs() == oldMemRefRank);
@@ -64,8 +64,8 @@
}
// Assert same elemental type.
- assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
- cast<MemRefType>(newMemRef->getType())->getElementType());
+ assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
+ newMemRef->getType().cast<MemRefType>().getElementType());
// Check if memref was used in a non-deferencing context.
for (const StmtOperand &use : oldMemRef->getUses()) {
@@ -139,7 +139,7 @@
opStmt->operand_end());
// Result types don't change. Both memref's are of the same elemental type.
- SmallVector<Type *, 8> resultTypes;
+ SmallVector<Type, 8> resultTypes;
resultTypes.reserve(opStmt->getNumResults());
for (const auto *result : opStmt->getResults())
resultTypes.push_back(result->getType());
diff --git a/lib/Transforms/Vectorize.cpp b/lib/Transforms/Vectorize.cpp
index d7a1f53..511afa9 100644
--- a/lib/Transforms/Vectorize.cpp
+++ b/lib/Transforms/Vectorize.cpp
@@ -202,15 +202,15 @@
/// sizes specified by vectorSize. The MemRef lives in the same memory space as
/// tmpl. The MemRef should be promoted to a closer memory address space in a
/// later pass.
-static MemRefType *getVectorizedMemRefType(MemRefType *tmpl,
- ArrayRef<int> vectorSizes) {
- auto *elementType = tmpl->getElementType();
- assert(!dyn_cast<VectorType>(elementType) &&
+static MemRefType getVectorizedMemRefType(MemRefType tmpl,
+ ArrayRef<int> vectorSizes) {
+ auto elementType = tmpl.getElementType();
+ assert(!elementType.dyn_cast<VectorType>() &&
"Can't vectorize an already vector type");
- assert(tmpl->getAffineMaps().empty() &&
+ assert(tmpl.getAffineMaps().empty() &&
"Unsupported non-implicit identity map");
return MemRefType::get({1}, VectorType::get(vectorSizes, elementType), {},
- tmpl->getMemorySpace());
+ tmpl.getMemorySpace());
}
/// Creates an unaligned load with the following semantics:
@@ -258,7 +258,7 @@
operands.insert(operands.end(), dstMemRef);
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
using functional::map;
- std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
+ std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
return v->getType();
};
auto types = map(getType, operands);
@@ -310,7 +310,7 @@
operands.insert(operands.end(), dstMemRef);
operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
using functional::map;
- std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
+ std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
return v->getType();
};
auto types = map(getType, operands);
@@ -348,8 +348,9 @@
template <typename LoadOrStoreOpPointer>
static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp,
ArrayRef<int> vectorSize) {
- auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType());
- auto *vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
+ auto memRefType =
+ memoryOp->getMemRef()->getType().template cast<MemRefType>();
+ auto vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
// Materialize a MemRef with 1 vector.
auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());